| @@ -34,6 +34,7 @@ const size_t kAxis_H = 2; | |||
| const size_t kAxis_W = 3; | |||
| const size_t kAxis_6HD_H = 1; | |||
| const size_t kAxis_6HD_W = 2; | |||
| const int64_t kAxisDim = 4; | |||
| const std::map<std::string, ConvertFunction> kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD}, | |||
| {kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}}; | |||
| void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int64_t> &reduce_axis) { | |||
| @@ -46,7 +47,7 @@ void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int64_t> &reduce | |||
| << "] is not single input or single output "; | |||
| } | |||
| for (auto elem : reduce_axis) { | |||
| if (elem > 4) { | |||
| if (elem > kAxisDim) { | |||
| MS_LOG(INFO) << "reduce axis is larger than 4 dims reduce axis : [" << elem << "]"; | |||
| } | |||
| } | |||
| @@ -104,6 +104,7 @@ bool CheckOtherOutputs(const CNodePtr &node, const std::shared_ptr<kernel::Kerne | |||
| } | |||
| bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::KernelBuildInfo> &kernel_info, size_t index) { | |||
| constexpr size_t kInferShapeSize = 4; | |||
| if (kernel_info == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -111,8 +112,8 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel | |||
| if (AnfAlgo::GetOutputDeviceDataType(node, 0) != kernel_info->GetOutputDeviceType(index)) { | |||
| return false; | |||
| } | |||
| if (AnfAlgo::GetOutputInferShape(node, 0).size() == 4 && AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && | |||
| kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { | |||
| if (AnfAlgo::GetOutputInferShape(node, 0).size() == kInferShapeSize && | |||
| AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { | |||
| return true; | |||
| } | |||
| return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); | |||
| @@ -102,12 +102,13 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_ | |||
| } | |||
| std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const { | |||
| constexpr size_t kFormatCount = 2; | |||
| std::string convert_format = kOpFormat_DEFAULT; | |||
| size_t counter = 0; | |||
| if (format_counter.size() > 2) { | |||
| if (format_counter.size() > kFormatCount) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { | |||
| if (format_counter.size() == kFormatCount && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| for (const auto &iter : format_counter) { | |||
| @@ -32,14 +32,17 @@ constexpr size_t kMaxPoolInputNum = 2; | |||
| constexpr size_t kMaxPoolAttrAxisNum = 4; | |||
| constexpr size_t kMaxPoolGradInputNum = 4; | |||
| constexpr size_t kMaxPoolWithArgmaxOutputNum = 2; | |||
| constexpr size_t kIndex1 = 1; | |||
| constexpr size_t kIndex2 = 2; | |||
| constexpr size_t kIndex3 = 3; | |||
| CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) { | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad); | |||
| if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " | |||
| << maxpool_grad->inputs().size() - 1; | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << (kMaxPoolGradInputNum - 1) << ", but got " | |||
| << (maxpool_grad->inputs().size() - 1); | |||
| } | |||
| auto maxpool_anf = maxpool_grad->input(2); | |||
| auto maxpool_anf = maxpool_grad->input(kIndex2); | |||
| MS_EXCEPTION_IF_NULL(maxpool_anf); | |||
| return maxpool_anf->cast<CNodePtr>(); | |||
| } | |||
| @@ -48,8 +51,8 @@ CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxp | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(maxpool); | |||
| if (maxpool->inputs().size() != kMaxPoolInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPool's input number should be " << kMaxPoolInputNum - 1 << ", but got " | |||
| << maxpool->inputs().size() - 1; | |||
| MS_LOG(EXCEPTION) << "MaxPool's input number should be " << (kMaxPoolInputNum - 1) << ", but got " | |||
| << (maxpool->inputs().size() - 1); | |||
| } | |||
| std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)), | |||
| maxpool->input(1)}; | |||
| @@ -71,14 +74,14 @@ CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr & | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad); | |||
| if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " | |||
| << maxpool_grad->inputs().size() - 1; | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << (kMaxPoolGradInputNum - 1) << ", but got " | |||
| << (maxpool_grad->inputs().size() - 1); | |||
| } | |||
| // MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are | |||
| // {input, grad_input, argmax_output} | |||
| std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1), | |||
| maxpool_grad->input(3), maxpool_argmax_outputs[1]}; | |||
| maxpool_grad->input(kIndex3), maxpool_argmax_outputs[1]}; | |||
| auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad_argmax); | |||
| maxpool_grad_argmax->set_scope(maxpool_grad->scope()); | |||
| @@ -99,12 +102,13 @@ void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const C | |||
| << ksize.size(); | |||
| } | |||
| // note that strides and ksize change from (1, 1, x, y) to (1, x, y, 1) | |||
| for (size_t i = 1; i <= 2; ++i) { | |||
| strides[i] = strides[i + 1]; | |||
| ksize[i] = ksize[i + 1]; | |||
| } | |||
| strides[3] = 1; | |||
| ksize[3] = 1; | |||
| strides[kIndex1] = strides[kIndex2]; | |||
| strides[kIndex2] = strides[kIndex3]; | |||
| strides[kIndex3] = 1; | |||
| ksize[kIndex1] = ksize[kIndex2]; | |||
| ksize[kIndex2] = ksize[kIndex3]; | |||
| ksize[kIndex3] = 1; | |||
| AnfAlgo::CopyNodeAttrs(maxpool, maxpool_argmax); | |||
| AnfAlgo::CopyNodeAttrs(maxpool_grad, maxpool_grad_argmax); | |||
| @@ -27,6 +27,13 @@ namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4; | |||
| constexpr size_t kMaxPoolWithArgmaxShape = 4; | |||
| constexpr size_t kAlignBytes = 16; | |||
| constexpr size_t kIndex1 = 1; | |||
| constexpr size_t kIndex2 = 2; | |||
| constexpr size_t kIndex3 = 3; | |||
| constexpr size_t kIndex4 = 4; | |||
| bool IsC(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||
| @@ -41,7 +48,7 @@ CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) { | |||
| if (maxpool_grad_with_argmax->inputs().size() != kMaxPoolGradWithArgmaxInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGradWithArgmax has wrong input size."; | |||
| } | |||
| auto tuple_getitem0_anf = maxpool_grad_with_argmax->input(3); | |||
| auto tuple_getitem0_anf = maxpool_grad_with_argmax->input(kIndex3); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); | |||
| return tuple_getitem0_anf->cast<CNodePtr>(); | |||
| } | |||
| @@ -64,11 +71,11 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_with_argmax, kAttrKernelSize); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0); | |||
| auto argmax_shape = output_shape; | |||
| if (argmax_shape.size() != 4) { | |||
| if (argmax_shape.size() != kMaxPoolWithArgmaxShape) { | |||
| MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; | |||
| } | |||
| argmax_shape[2] = ksize[1] * ksize[2]; | |||
| argmax_shape[3] = (output_shape[2] * output_shape[3] + 15) / 16 + 1; | |||
| argmax_shape[kIndex2] = ksize[kIndex1] * ksize[kIndex2]; | |||
| argmax_shape[kIndex3] = (output_shape[kIndex2] * output_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1; | |||
| auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype}; | |||
| auto shapes = {output_shape, argmax_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get()); | |||
| @@ -98,11 +105,11 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g | |||
| TypeId argmax_dtype = kNumberTypeUInt16; | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_grad_with_argmax, kAttrKernelSize); | |||
| auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0); | |||
| if (argmax_shape.size() != 4) { | |||
| if (argmax_shape.size() != kMaxPoolWithArgmaxShape) { | |||
| MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; | |||
| } | |||
| argmax_shape[3] = (argmax_shape[2] * argmax_shape[3] + 15) / 16 + 1; | |||
| argmax_shape[2] = ksize[1] * ksize[2]; | |||
| argmax_shape[kIndex3] = (argmax_shape[kIndex2] * argmax_shape[kIndex3] + kAlignBytes - 1) / kAlignBytes + 1; | |||
| argmax_shape[kIndex2] = ksize[kIndex1] * ksize[kIndex2]; | |||
| AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get()); | |||
| return maxpool_grad_with_argmax; | |||
| @@ -93,7 +93,7 @@ namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| const int FLOAT_LEN = sizeof(float); | |||
| const int FLOAT16_LEN = 2; // sizeof(float16); | |||
| const int FLOAT16_LEN = 2; | |||
| const std::set<std::string> kOpNeedTransFormat = { | |||
| kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, | |||
| kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | |||
| @@ -82,8 +82,8 @@ std::string GetRankId() { | |||
| MS_EXCEPTION_IF_NULL(mpi_config_ptr); | |||
| if (mpi_config_ptr->enable_mpi()) { | |||
| int rank_id = GetMPIRankId(); | |||
| const char *offset = std::getenv("RANK_OFFSET"); | |||
| if (offset != nullptr) { | |||
| const char *offset_env = std::getenv("RANK_OFFSET"); | |||
| const std::string offset = (offset_env == nullptr) ? "" : offset_env; | |||
| if (!offset.empty()) { | |||
| try { | |||
| int rank_offset = std::stoi(offset); | |||
| rank_id += rank_offset; | |||
| @@ -578,7 +578,8 @@ void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph *grap | |||
| } | |||
| } | |||
| bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||
| bool AscendKernelRuntime::Run(session::KernelGraph *const graph, bool is_task_sink) { | |||
| const uint64_t kUSecondInSecond = 1000000; | |||
| SignalGuard sg; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| bool ret = false; | |||
| @@ -596,11 +597,10 @@ bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||
| } | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| auto end_time = std::chrono::steady_clock::now(); | |||
| std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time; | |||
| std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time; | |||
| MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; | |||
| #else | |||
| (void)gettimeofday(&end_time, nullptr); | |||
| const uint64_t kUSecondInSecond = 1000000; | |||
| uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | |||
| cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | |||
| MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; | |||
| @@ -779,8 +779,9 @@ bool AscendKernelRuntime::InitDevice() { | |||
| bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { | |||
| SetCurrentContext(); | |||
| int32_t ret; | |||
| if (stream_ != nullptr) { | |||
| auto ret = rtStreamDestroy(stream_); | |||
| ret = rtStreamDestroy(stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; | |||
| } | |||
| @@ -788,14 +789,14 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) { | |||
| } | |||
| if (communication_stream_ != nullptr) { | |||
| auto ret = rtStreamDestroy(communication_stream_); | |||
| ret = rtStreamDestroy(communication_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; | |||
| } | |||
| communication_stream_ = nullptr; | |||
| } | |||
| auto ret = rtDeviceReset(device_id); | |||
| ret = rtDeviceReset(device_id); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]"; | |||
| } | |||
| @@ -20,12 +20,12 @@ | |||
| #include "runtime/device/ascend/ascend_label_assign.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| static constexpr uint32_t kLabelGotoLabelId = 1; | |||
| static constexpr uint32_t kLabelSwitchLabelId = 2; | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| static constexpr uint32_t kLabelGotoLabelId = 1; | |||
| static constexpr uint32_t kLabelSwitchLabelId = 2; | |||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | |||
| return; | |||
| @@ -142,7 +142,7 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr | |||
| std::lock_guard<std::mutex> lock(label_num_mutex_); | |||
| auto iter = label_num_.find(graph.get()); | |||
| if (iter == label_num_.end()) { | |||
| MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0."; | |||
| MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, default is 0."; | |||
| return 0; | |||
| } | |||
| return iter->second; | |||
| @@ -132,6 +132,8 @@ bool KernelAdjust::ExistIndependent(const std::shared_ptr<session::KernelGraph> | |||
| void KernelAdjust::InsertIndepentParallel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | |||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input, | |||
| std::vector<CNodePtr> *exec_order) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||
| MS_EXCEPTION_IF_NULL(exec_order); | |||
| device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); | |||
| CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch); | |||
| MS_EXCEPTION_IF_NULL(independent_switch_app); | |||
| @@ -40,6 +40,8 @@ using mindspore::kernel::AddressPtr; | |||
| namespace mindspore { | |||
| namespace device { | |||
| constexpr size_t kMinInputSize = 2; | |||
| KernelRuntime::~KernelRuntime() {} | |||
| bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; } | |||
| @@ -258,7 +260,7 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value | |||
| continue; | |||
| } | |||
| if (opt::IsNopNode(real_output_cnode)) { | |||
| if (real_output_cnode->inputs().size() < 2) { | |||
| if (real_output_cnode->inputs().size() < kMinInputSize) { | |||
| MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString() | |||
| << " should large than one!"; | |||
| } | |||
| @@ -280,14 +282,14 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); | |||
| std::vector<AnfNodePtr> need_alloc_nodes; | |||
| for (size_t i = 0; i < graph_inputs.size(); ++i) { | |||
| auto item = graph_inputs[i]; | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| auto input_node = graph_inputs[i]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (i < graph_valid_input.size() && !graph_valid_input[i]) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { | |||
| auto outs = AnfAlgo::GetAllOutput(item); | |||
| if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { | |||
| auto outs = AnfAlgo::GetAllOutput(input_node); | |||
| for (auto &out : outs) { | |||
| MS_EXCEPTION_IF_NULL(out); | |||
| if (!out->isa<Parameter>()) { | |||
| @@ -299,13 +301,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| need_alloc_nodes.push_back(out); | |||
| } | |||
| } | |||
| if (!item->isa<Parameter>()) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| if (NodeOutputDeviceAddressExist(item, 0)) { | |||
| if (NodeOutputDeviceAddressExist(input_node, 0)) { | |||
| continue; | |||
| } | |||
| need_alloc_nodes.push_back(item); | |||
| need_alloc_nodes.push_back(input_node); | |||
| } | |||
| #if (ENABLE_CPU && !_WIN32) | |||
| bool ps_cache_check = false; | |||
| @@ -358,15 +360,15 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) | |||
| std::vector<session::KernelWithIndex> non_communication_op; | |||
| // Assign Communicate Op Memory firstly. | |||
| for (const auto &node : nodes) { | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| MS_EXCEPTION_IF_NULL(item_with_index.first); | |||
| if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) { | |||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| MS_EXCEPTION_IF_NULL(kernel_with_index.first); | |||
| if (!kernel_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_with_index.first)) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { | |||
| AssignCommunicationNodeMem(kStaticMem, item_with_index.first); | |||
| if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) { | |||
| AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first); | |||
| } else { | |||
| non_communication_op.emplace_back(item_with_index); | |||
| non_communication_op.emplace_back(kernel_with_index); | |||
| } | |||
| } | |||
| @@ -595,7 +597,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() < 2) { | |||
| if (cnode->inputs().size() < kMinInputSize) { | |||
| // communication node's input should contain itself and at least one input | |||
| MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope(); | |||
| return; | |||
| @@ -27,8 +27,10 @@ using mindspore::memreuse::MemReuseUtilPtr; | |||
| namespace mindspore { | |||
| namespace device { | |||
| constexpr size_t kAlignBytes = 32; | |||
| size_t MemoryManager::GetCommonAlignSize(size_t input_size) { | |||
| return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; | |||
| return (input_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize * kMemAlignSize; | |||
| } | |||
| size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { | |||