From: @limingqi107 | Reviewed-by: @cristoval | tags/v1.1.0
@@ -46,6 +46,7 @@ constexpr auto kTopKV2 = "TopKV2";
constexpr auto kEditDistance = "EditDistance";
constexpr auto kGatherD = "GatherD";
constexpr auto kIdentity = "Identity";
constexpr auto kUpdateCache = "UpdateCache";
constexpr auto kCustRunApi = "RunCpuKernel";
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity};
@@ -0,0 +1,64 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hash_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                            const int hash_dim) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
    int hash_index = swap_out_index[i];
    for (int j = 0; j < hash_dim; j++) {
      swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j];
    }
  }
  return;
}

template <typename T>
__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                           const int hash_dim) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
    int hash_index = swap_in_index[i];
    for (int j = 0; j < hash_dim; j++) {
      hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j];
    }
  }
  return;
}

template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                   const int hash_dim, cudaStream_t cuda_stream) {
  HashSwapOut<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_out_value, swap_out_index,
                                                                       index_size, hash_dim);
  return;
}

template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                  const int hash_dim, cudaStream_t cuda_stream) {
  HashSwapIn<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_in_value, swap_in_index,
                                                                      index_size, hash_dim);
  return;
}

template void DoHashSwapOut<float>(const float *hash_table, float *swap_out_value, const int *swap_out_index,
                                   const int index_size, const int hash_dim, cudaStream_t cuda_stream);
template void DoHashSwapIn<float>(float *hash_table, const float *swap_in_value, const int *swap_in_index,
                                  const int index_size, const int hash_dim, cudaStream_t cuda_stream);
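For orientation, here is a minimal host-side sketch of driving these launchers. The buffer sizes, names, and the standalone-driver framing are all illustrative; the real call sites are in gpu_ps_cache.cc later in this PR.

```cpp
// Hypothetical driver for DoHashSwapOut (sketch only): evict kSwapCount rows
// of width kDim from a kVocab-row device-resident hash table.
#include <cuda_runtime_api.h>
#include "hash_impl.cuh"

void SwapOutSketch(cudaStream_t stream) {
  const int kVocab = 1024;    // rows in the hash table (assumed)
  const int kDim = 8;         // embedding width per row (assumed)
  const int kSwapCount = 16;  // rows to swap out this step (assumed)

  float *hash_table = nullptr;
  float *swapped = nullptr;
  int *indices = nullptr;
  cudaMalloc(&hash_table, kVocab * kDim * sizeof(float));
  cudaMalloc(&swapped, kSwapCount * kDim * sizeof(float));
  cudaMalloc(&indices, kSwapCount * sizeof(int));
  // ... populate hash_table and indices on device ...

  // The grid-stride kernel copies row indices[i] of the table into swapped[i].
  DoHashSwapOut(hash_table, swapped, indices, kSwapCount, kDim, stream);
  cudaStreamSynchronize(stream);

  cudaFree(indices);
  cudaFree(swapped);
  cudaFree(hash_table);
}
```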
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_

#include <cuda_runtime_api.h>  // for cudaStream_t

template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                   const int hash_dim, cudaStream_t cuda_stream);

template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                  const int hash_dim, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
@@ -116,6 +116,8 @@ using KernelPackPtr = std::shared_ptr<KernelPack>;
 * @brief base class for autotensor kernel and cce kernel.
 */
struct Address {
  Address() : addr(nullptr), size(0) {}
  Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
  void *addr;
  size_t size;
};
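The two-argument constructor is what lets the PS cache code below build kernel launch arguments inline. A minimal sketch of that pattern, assuming this header (mindspore::kernel::Address) is included; the helper name is illustrative:

```cpp
#include <memory>
#include <vector>

// Sketch: wrap a raw device pointer and its byte size into the
// AddressPtrList form that KernelMod::Launch consumes.
using AddressPtr = std::shared_ptr<mindspore::kernel::Address>;
using AddressPtrList = std::vector<AddressPtr>;

AddressPtrList WrapSingleBuffer(void *dev_ptr, size_t byte_size) {
  AddressPtrList list;
  list.push_back(std::make_shared<mindspore::kernel::Address>(dev_ptr, byte_size));
  return list;
}
```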
@@ -16,5 +16,16 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
  list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
endif ()

if (NOT ENABLE_D)
  list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
endif()

if (NOT ENABLE_GPU)
  list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
endif()

list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
add_subdirectory(ps_cache)

set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
@@ -0,0 +1,7 @@
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
  file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
  set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
  add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
endif()
@@ -0,0 +1,253 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/ps_cache/ascend/ascend_ps_cache.h"
#include <google/protobuf/text_format.h>
#include <string>
#include <vector>
#include <memory>
#include "ps/ps_cache/ps_cache_factory.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "utils/ms_context.h"
#include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h"
#include "proto/node_def.pb.h"

using mindspore::kernel::Address;
using AddressPtr = std::shared_ptr<Address>;
using AddressPtrList = std::vector<AddressPtr>;

namespace mindspore {
namespace ps {
namespace ascend {
MS_REG_PS_CACHE(kAscendDevice, AscendPsCache);

namespace {
void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
                    mindspore::NodeDef *proto) {
  MS_EXCEPTION_IF_NULL(proto);
  if (data_shape.size() != data_type.size()) {
    MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
  }
  for (size_t input_index = 0; input_index < data_shape.size(); input_index++) {
    ::mindspore::Tensor *proto_inputs = proto->add_inputs();
    MS_EXCEPTION_IF_NULL(proto_inputs);
    auto input_shape = data_shape[input_index];
    mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape();
    MS_EXCEPTION_IF_NULL(tensorShape);
    for (auto item : input_shape) {
      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
      MS_EXCEPTION_IF_NULL(dim);
      dim->set_size((::google::protobuf::int64)item);
    }
    auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]);
    proto_inputs->set_tensor_type(input_type);
    proto_inputs->set_mem_device("HBM");
  }
}

void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
                     mindspore::NodeDef *proto) {
  MS_EXCEPTION_IF_NULL(proto);
  if (data_shape.size() != data_type.size()) {
    MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
  }
  for (size_t output_index = 0; output_index < data_shape.size(); output_index++) {
    ::mindspore::Tensor *proto_outputs = proto->add_outputs();
    MS_EXCEPTION_IF_NULL(proto_outputs);
    auto output_shape = data_shape[output_index];
    mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape();
    MS_EXCEPTION_IF_NULL(tensorShape);
    for (auto item : output_shape) {
      mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
      MS_EXCEPTION_IF_NULL(dim);
      dim->set_size((::google::protobuf::int64)item);
    }
    auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]);
    proto_outputs->set_tensor_type(output_type);
    proto_outputs->set_mem_device("HBM");
  }
}

void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
                     const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) {
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  mindspore::NodeDef proto;
  proto.set_op(op_info->op_name_);
  SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto);
  SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto);
  std::string nodeDefStr;
  if (!proto.SerializeToString(&nodeDefStr)) {
    MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed.";
  }
  MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_;
  kernel_mod_ptr->SetNodeDef(nodeDefStr);
}
}  // namespace

void AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
  MS_EXCEPTION_IF_NULL(context);
  auto ret = rtSetDevice(device_id);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]";
  }
  auto rt_context = const_cast<rtContext_t>(context);
  ret = rtCtxSetCurrent(rt_context);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
  }
  ret = rtStreamCreate(&stream_, 0);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]";
  }
}

void *AscendPsCache::MallocMemory(size_t size) {
  return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
}

void AscendPsCache::MallocConstantMemory(size_t constant_value) {
  offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
  MS_EXCEPTION_IF_NULL(offset_addr_);
  rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
  cache_vocab_size_addr_ =
    reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
  MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_);
  rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
}

void AscendPsCache::RecordEvent() {
  event_.reset(new rtEvent_t());
  auto ret = rtEventCreate(&(*event_));
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Create event failed";
  }
  ret = rtEventRecord(*event_, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Record event failed";
  }
}

void AscendPsCache::SynchronizeEvent() {
  auto ret = rtEventSynchronize(*event_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "rtEventSynchronize failed";
  }
  ret = rtEventDestroy(*event_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed";
  }
}

void AscendPsCache::SynchronizeStream() {
  auto ret = rtStreamSynchronize(stream_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed";
  }
}

void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
  MS_EXCEPTION_IF_NULL(dst);
  MS_EXCEPTION_IF_NULL(src);
  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
  }
}

void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
  MS_EXCEPTION_IF_NULL(dst);
  MS_EXCEPTION_IF_NULL(src);
  auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
  }
}

void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
                                size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
  MS_EXCEPTION_IF_NULL(hash_table_addr);
  MS_EXCEPTION_IF_NULL(swap_out_value_addr);
  MS_EXCEPTION_IF_NULL(swap_out_index_addr);
  auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>();
  MS_EXCEPTION_IF_NULL(hash_swap_out_mod);
  hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName);
  std::vector<std::vector<size_t>> input_shape;
  std::vector<std::vector<size_t>> output_shape;
  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
  std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
  input_shape.push_back({hash_table_size, embedding_size});
  input_shape.push_back({swap_out_size});
  input_shape.push_back({1});
  output_shape.push_back({swap_out_size, embedding_size});
  auto op_info =
    std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
  SetNodedefProto(op_info, hash_swap_out_mod);
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_outputs = {
    std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
  AddressPtrList kernel_workspaces;
  kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
  kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
  kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
  auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Hash swap out launch failed.";
  }
}

void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
                               size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
  MS_EXCEPTION_IF_NULL(hash_table_addr);
  MS_EXCEPTION_IF_NULL(swap_in_value_addr);
  MS_EXCEPTION_IF_NULL(swap_in_index_addr);
  auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>();
  MS_EXCEPTION_IF_NULL(hash_swap_in_mod);
  hash_swap_in_mod->SetNodeName(kernel::kUpdateCache);
  std::vector<std::vector<size_t>> input_shape;
  std::vector<std::vector<size_t>> output_shape;
  std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
                                    TypeId::kNumberTypeInt32};
  std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
  input_shape.push_back({hash_table_size, embedding_size});
  input_shape.push_back({swap_in_size});
  input_shape.push_back({swap_in_size, embedding_size});
  input_shape.push_back({1});
  output_shape.push_back({1});
  auto op_info =
    std::make_shared<KernelNodeInfo>(kernel::kUpdateCache, input_shape, input_type, output_shape, output_type);
  SetNodedefProto(op_info, hash_swap_in_mod);
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_outputs;
  AddressPtrList kernel_workspaces;
  kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
  kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
  kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
  kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
  // The UpdateCache kernel requires an output address, but the value is unused, so any valid address works.
  kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
  auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Hash swap in launch failed.";
  }
}
}  // namespace ascend
}  // namespace ps
}  // namespace mindspore
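To make the two AI CPU launches easier to audit, here is the shape and dtype layout they encode, written out as annotation (derived from the code above: swap-out is expressed as an EmbeddingLookup gather, swap-in as an UpdateCache scatter):

```cpp
// Swap out == EmbeddingLookup (gather rows out of the table):
//   input 0: hash_table      [hash_table_size, embedding_size]  float32
//   input 1: swap_out_index  [swap_out_size]                    int32
//   input 2: offset          [1]                                int32 (device constant, 0)
//   output : swap_out_value  [swap_out_size, embedding_size]    float32
//
// Swap in == UpdateCache (scatter rows back into the table):
//   input 0: hash_table      [hash_table_size, embedding_size]  float32
//   input 1: swap_in_index   [swap_in_size]                     int32
//   input 2: swap_in_value   [swap_in_size, embedding_size]     float32
//   input 3: max_num         [1]                                int32 (cache vocab size)
//   output : [1] int32, required by the op but unused (any valid address works)
```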
@@ -0,0 +1,73 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_

#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "ps/ps_cache/ps_cache_basic.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "ir/dtype.h"

namespace mindspore {
namespace ps {
namespace ascend {
struct KernelNodeInfo {
  KernelNodeInfo(const std::string &op_name, std::vector<std::vector<size_t>> input_data_shape,
                 std::vector<TypeId> input_data_type, std::vector<std::vector<size_t>> output_data_shape,
                 std::vector<TypeId> output_data_type)
      : op_name_(op_name) {
    input_data_shape_.swap(input_data_shape);
    input_data_type_.swap(input_data_type);
    output_data_shape_.swap(output_data_shape);
    output_data_type_.swap(output_data_type);
  }
  std::string op_name_;
  std::vector<std::vector<size_t>> input_data_shape_;
  std::vector<TypeId> input_data_type_;
  std::vector<std::vector<size_t>> output_data_shape_;
  std::vector<TypeId> output_data_type_;
};

class AscendPsCache : public PsCacheBasic {
 public:
  AscendPsCache() = default;
  ~AscendPsCache() override = default;
  void InitDevice(uint32_t device_id, const void *context) override;
  void *MallocMemory(size_t size) override;
  void MallocConstantMemory(size_t constant_value) override;
  void RecordEvent() override;
  void SynchronizeEvent() override;
  void SynchronizeStream() override;
  void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
  void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
  void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
                   size_t embedding_size, size_t swap_out_size) override;
  void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
                  size_t embedding_size, size_t swap_in_size) override;

 private:
  int *offset_addr_{nullptr};
  int *cache_vocab_size_addr_{nullptr};
  std::unique_ptr<rtEvent_t> event_;
};
}  // namespace ascend
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_
@@ -0,0 +1,92 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/ps_cache/gpu/gpu_ps_cache.h"
#include "ps/ps_cache/ps_cache_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh"
#include "runtime/device/gpu/gpu_common.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace ps {
namespace gpu {
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);

void GPUPsCache::InitDevice(uint32_t device_id, const void *) {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
                                     "Cuda create stream failed");
}

void *GPUPsCache::MallocMemory(size_t size) {
  return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
}

void GPUPsCache::RecordEvent() {
  event_.reset(new cudaEvent_t());
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
                                     "Cuda record event failed");
}

void GPUPsCache::SynchronizeEvent() {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
}

void GPUPsCache::SynchronizeStream() {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
                                     "Cuda sync stream failed");
}

void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
  MS_EXCEPTION_IF_NULL(dst);
  MS_EXCEPTION_IF_NULL(src);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
    "Cuda memcpy failed");
}

void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
  MS_EXCEPTION_IF_NULL(dst);
  MS_EXCEPTION_IF_NULL(src);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
    "Cuda memcpy failed");
}

void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
                             size_t embedding_size, size_t swap_out_size) {
  MS_EXCEPTION_IF_NULL(hash_table_addr);
  MS_EXCEPTION_IF_NULL(swap_out_value_addr);
  MS_EXCEPTION_IF_NULL(swap_out_index_addr);
  DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
                reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
                reinterpret_cast<cudaStream_t>(stream_));
}

void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
                            size_t embedding_size, size_t swap_in_size) {
  MS_EXCEPTION_IF_NULL(hash_table_addr);
  MS_EXCEPTION_IF_NULL(swap_in_value_addr);
  MS_EXCEPTION_IF_NULL(swap_in_index_addr);
  DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
               reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
               reinterpret_cast<cudaStream_t>(stream_));
}
}  // namespace gpu
}  // namespace ps
}  // namespace mindspore
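The copy, event, and synchronize methods are meant to compose into an async prefetch pattern. The actual caller (the PS cache manager) is outside this diff, so the following is a hedged sketch of the intended sequence; the function name is illustrative:

```cpp
#include "ps/ps_cache/ps_cache_basic.h"

// Hypothetical caller (sketch): enqueue an async host-to-device copy on the
// cache stream, mark it with an event, overlap other host work, then wait on
// the event alone rather than on the whole stream.
void PrefetchSketch(mindspore::ps::PsCacheBasic *cache, void *host_buf, size_t size) {
  void *dev_buf = cache->MallocMemory(size);
  cache->CopyHostMemToDevice(dev_buf, host_buf, size);  // asynchronous enqueue
  cache->RecordEvent();
  // ... unrelated host-side work can overlap here ...
  cache->SynchronizeEvent();  // blocks only until the recorded copy finishes
}
```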
@@ -0,0 +1,49 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_

#include <cuda_runtime_api.h>
#include <memory>
#include "ps/ps_cache/ps_cache_basic.h"

namespace mindspore {
namespace ps {
namespace gpu {
class GPUPsCache : public PsCacheBasic {
 public:
  GPUPsCache() = default;
  ~GPUPsCache() override = default;
  void InitDevice(uint32_t device_id, const void *context) override;
  void *MallocMemory(size_t size) override;
  void RecordEvent() override;
  void SynchronizeEvent() override;
  void SynchronizeStream() override;
  void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
  void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
  void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
                   size_t embedding_size, size_t swap_out_size) override;
  void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
                  size_t embedding_size, size_t swap_in_size) override;

 private:
  std::unique_ptr<cudaEvent_t> event_;
};
}  // namespace gpu
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
@@ -0,0 +1,46 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H

#include <utility>
#include <memory>

namespace mindspore {
namespace ps {
class PsCacheBasic {
 public:
  PsCacheBasic() = default;
  virtual ~PsCacheBasic() = default;
  virtual void InitDevice(uint32_t device_id, const void *context) = 0;
  virtual void *MallocMemory(size_t size) = 0;
  virtual void MallocConstantMemory(size_t constant_value) {}
  virtual void RecordEvent() = 0;
  virtual void SynchronizeEvent() = 0;
  virtual void SynchronizeStream() = 0;
  virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
  virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
  virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
                           size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
  virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
                          size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;

 protected:
  void *stream_{nullptr};
};
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H
@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/ps_cache/ps_cache_factory.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace ps {
PsCacheFactory &PsCacheFactory::Get() {
  static PsCacheFactory instance;
  return instance;
}

void PsCacheFactory::Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator) {
  if (ps_cache_creators_.end() == ps_cache_creators_.find(device_name)) {
    (void)ps_cache_creators_.emplace(device_name, ps_cache_creator);
  }
}

std::shared_ptr<PsCacheBasic> PsCacheFactory::ps_cache(const std::string &device_name) {
  auto iter = ps_cache_creators_.find(device_name);
  if (ps_cache_creators_.end() != iter) {
    MS_EXCEPTION_IF_NULL(iter->second);
    return (iter->second)();
  }
  return nullptr;
}
}  // namespace ps
}  // namespace mindspore
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "ps/ps_cache/ps_cache_basic.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace ps {
using PsCacheCreator = std::function<std::shared_ptr<PsCacheBasic>()>;

class PsCacheFactory {
 public:
  static PsCacheFactory &Get();
  void Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator);
  std::shared_ptr<PsCacheBasic> ps_cache(const std::string &device_name);

 private:
  PsCacheFactory() = default;
  ~PsCacheFactory() = default;
  DISABLE_COPY_AND_ASSIGN(PsCacheFactory)
  std::map<std::string, PsCacheCreator> ps_cache_creators_;
};

class PsCacheRegistrar {
 public:
  PsCacheRegistrar(const std::string &device_name, PsCacheCreator &&ps_cache_creator) {
    PsCacheFactory::Get().Register(device_name, std::move(ps_cache_creator));
  }
  ~PsCacheRegistrar() = default;
};

#define MS_REG_PS_CACHE(DEVICE_NAME, PS_CACHE_CLASS)                          \
  static const PsCacheRegistrar g_ps_cache_registrar__##DEVICE_NAME##_##_reg( \
    DEVICE_NAME, []() { return std::make_shared<PS_CACHE_CLASS>(); });
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
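Taken together: a backend registers itself at namespace scope with MS_REG_PS_CACHE (as the Ascend and GPU files in this PR do), and a consumer retrieves it by device name. A small usage sketch; the function name is illustrative, and the caller is responsible for the nullptr case:

```cpp
#include <memory>
#include <string>
#include "ps/ps_cache/ps_cache_factory.h"

// Sketch: fetch the cache backend registered for a device name. The lookup
// key must match the name passed to MS_REG_PS_CACHE (e.g. the kGPUDevice /
// kAscendDevice constants); nullptr is returned when nothing was registered.
std::shared_ptr<mindspore::ps::PsCacheBasic> GetCacheForDevice(const std::string &device_name) {
  return mindspore::ps::PsCacheFactory::Get().ps_cache(device_name);
}
```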
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/ps_cache/ps_data/ps_data_channel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace ps {
void PsDataChannel::TryLockChannel() {
  // The data prefetch order needs to be consistent with the graph execution order.
  // For example, if the graph execution order is graph1 --> graph2 --> graph1 --> graph2,
  // then the data prefetch order needs to be channel1 --> channel2 --> channel1 --> channel2.
  if ((current_data_step_ != 0) && (current_data_step_ % step_num_ == 0)) {
    MS_LOG(INFO) << "Lock channel:" << channel_name_;
    std::unique_lock<std::mutex> locker(channel_mutex_);
    channel_.wait(locker, [this] { return channel_open_; });
    channel_open_ = false;
  }
  current_data_step_++;
}

void PsDataChannel::TryWakeChannel() {
  if ((current_graph_step_ != 0) && (current_graph_step_ % step_num_ == 0)) {
    MS_LOG(INFO) << "Wake up channel:" << channel_name_;
    std::lock_guard<std::mutex> locker(channel_mutex_);
    channel_open_ = true;
    channel_.notify_one();
  }
  current_graph_step_++;
}

void PsDataChannel::set_data(void *data, const size_t data_size) {
  MS_EXCEPTION_IF_NULL(data);
  TryLockChannel();
  data_ = data;
  data_size_ = data_size;
}
}  // namespace ps
}  // namespace mindspore
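The lock/wake pairing above is easiest to see as a two-thread handshake: the data-prefetch thread calls set_data (which may block inside TryLockChannel at epoch boundaries), and the graph-execution thread calls TryWakeChannel once per executed step to release it. A minimal sketch; the thread bodies and function name are hypothetical:

```cpp
#include <thread>
#include "ps/ps_cache/ps_data/ps_data_channel.h"

// Sketch of the intended pairing between the prefetch and graph threads.
void HandshakeSketch(mindspore::ps::PsDataChannel *channel, void *batch, size_t size) {
  std::thread prefetch([&] {
    channel->set_data(batch, size);  // may block until the channel is woken
  });
  channel->TryWakeChannel();  // graph side: one wake per executed step
  prefetch.join();
}
```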
@@ -0,0 +1,58 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_

#include <memory>
#include <string>
#include <mutex>
#include <condition_variable>

namespace mindspore {
namespace ps {
class PsDataChannel {
 public:
  PsDataChannel(const std::string &channel_name, size_t step_num)
      : channel_name_(channel_name),
        step_num_(step_num),
        current_data_step_(0),
        current_graph_step_(0),
        channel_open_(false),
        data_(nullptr),
        data_size_(0) {}
  virtual ~PsDataChannel() = default;
  void set_data(void *data, const size_t data_size);
  void *data() const { return data_; }
  size_t data_size() const { return data_size_; }
  void ResetData() { data_ = nullptr; }
  void set_step_num(size_t step_num) { step_num_ = step_num; }
  void TryWakeChannel();

 private:
  void TryLockChannel();
  std::string channel_name_;
  // The number of steps in each epoch.
  size_t step_num_;
  size_t current_data_step_;
  size_t current_graph_step_;
  bool channel_open_;
  std::mutex channel_mutex_;
  std::condition_variable channel_;
  void *data_;
  size_t data_size_;
};
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_
@@ -53,6 +53,15 @@ namespace gpu {
    }                                                                                      \
  }

#define CHECK_CUDA_RET_WITH_ERROR_NOTRACE(expression, message)                             \
  {                                                                                        \
    cudaError_t status = (expression);                                                     \
    if (status != cudaSuccess) {                                                           \
      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " "   \
                    << cudaGetErrorString(status);                                         \
    }                                                                                      \
  }

@@ -62,6 +71,15 @@ namespace gpu {
    }                                                                                      \
  }

#define CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(expression, message)                            \
  {                                                                                        \
    cudaError_t status = (expression);                                                     \
    if (status != cudaSuccess) {                                                           \
      MS_LOG(EXCEPTION) << "CUDA Error: " << message << " | Error Number: " << status      \
                        << " " << cudaGetErrorString(status);                              \
    }                                                                                      \
  }
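The new NOTRACE variants exist for call sites such as the PS cache, where no kernel node is available to attach to a trace (the existing CHECK_CUDA_RET_WITH_EXCEPT takes a `node` argument for that). A hedged usage sketch; the wrapper function is illustrative:

```cpp
#include <cuda_runtime_api.h>
#include "runtime/device/gpu/cuda_common.h"

// Hypothetical call site: no AnfNode is in scope, so the NOTRACE variant is
// used. On failure it logs the CUDA error string and throws via MS_LOG(EXCEPTION).
void CopyAsyncOrThrow(void *dst, const void *src, size_t n, cudaStream_t stream) {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(dst, src, n, cudaMemcpyHostToDevice, stream),
    "Cuda memcpy failed");
}
```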
#define CHECK_CUDNN_RET_WITH_EXCEPT(node, expression, message)                             \
  {                                                                                        \
    cudnnStatus_t status = (expression);                                                   \
@@ -139,6 +139,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")