@@ -174,7 +174,13 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
    device_shape = trans::TransShapeToDevice(host_shape, format_);
  } else {
-    host_shape = trans::PaddingShapeTo4d(host_shape);
+    if (host_shape_.empty()) {
+      host_shape = trans::PaddingShapeTo4d(host_shape);
+    } else {
+      host_shape.clear();
+      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
+    }
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
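
Note: this hunk makes the sync path prefer a host shape recorded on the device address over blind 4-d padding; only when `host_shape_` is empty does it fall back to `trans::PaddingShapeTo4d`. A minimal sketch of the widening idiom, assuming `IntToSize` is a plain int-to-`size_t` cast (the stand-in below is not the real helper):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

// Assumed stand-in for MindSpore's IntToSize helper.
std::size_t IntToSize(int v) { return static_cast<std::size_t>(v); }

int main() {
  std::vector<int> host_shape_{2, 3, 16, 16};  // shape recorded via set_host_shape
  std::vector<std::size_t> host_shape;
  // Same transform as the hunk: widen int dims to size_t before TransShapeToDevice.
  (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
  for (std::size_t d : host_shape) std::cout << d << ' ';  // prints: 2 3 16 16
  return 0;
}
```
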
@@ -47,6 +47,7 @@ class AscendDeviceAddress : public DeviceAddress {
   bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
                      const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
 #endif
+
  private:
   bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
   bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
@@ -63,6 +63,7 @@ class DeviceAddress {
   size_t GetSize() const { return size_; }
   std::string format() const { return format_; }
   TypeId type_id() const { return type_id_; }
+  void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
   virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@@ -77,6 +78,7 @@ class DeviceAddress {
   string format_{"DefaultFormat"};
   TypeId type_id_{kNumberTypeFloat16};
   bool from_mem_pool_{false};
+  std::vector<int> host_shape_{};
   friend class KernelRuntime;
   friend class MemoryManager;
   friend class mindspore::device::ascend::tasksink::TaskGenerator;
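
Note: `DeviceAddress` now carries the host shape next to the device-side metadata, with empty meaning "nothing recorded, pad to 4-d". A self-contained sketch of that contract, with `HasHostShape` as a hypothetical helper (the real class exposes only the setter shown above):

```cpp
#include <iostream>
#include <vector>

// Illustrative reduction of the DeviceAddress change; the real class has far more state.
class DeviceAddress {
 public:
  void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
  bool HasHostShape() const { return !host_shape_.empty(); }  // hypothetical helper

 protected:
  std::vector<int> host_shape_{};  // empty => sync path falls back to PaddingShapeTo4d
};

int main() {
  DeviceAddress addr;
  std::cout << addr.HasHostShape() << '\n';  // 0: would pad to 4-d
  addr.set_host_shape({2, 3, 16, 16});
  std::cout << addr.HasHostShape() << '\n';  // 1: recorded shape wins
  return 0;
}
```
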
@@ -258,6 +258,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
     std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
     auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
     MS_EXCEPTION_IF_NULL(device_address);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
     if (!ret) {
@@ -506,7 +507,9 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
     }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
+    auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
+    AnfAlgo::SetOutputAddr(device_address, i, node.get());
   }
 }
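
Note: with the two hunks above, both allocation paths (the single-op output path and graph node output assignment) stamp `trans::GetRuntimePaddingShape(...)` onto the address as soon as it is created, so a later `SyncDeviceToHostAndConvertFormat` can recover the real host shape even for outputs produced outside a whole-graph context.
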
@@ -238,16 +238,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  if (context_ptr->execution_mode() == kPynativeMode) {
-    ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
-  } else {
-    ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
-  }
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
@@ -287,8 +282,11 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
+  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
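
Note: taken together, these two hunks stop special-casing PyNative in the graph-mode pipeline (the `kPynativeMode` branch with `BnSplit`/`BnGradSplit` is gone; the batch-norm split and fusion passes now run unconditionally there) and instead register `BnSplit`, `BatchNormGradSplit`, and the `FusedBatchNorm*` fusions directly in the run-op pipeline, so both modes presumably see the same batch-norm handling.
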
@@ -84,14 +84,13 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
-        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False
-        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        if self.is_ge_backend or self.is_ascend:
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -153,7 +152,7 @@ class _BatchNorm(Cell):
         if self.is_ge_backend and self.is_global:
             axes, re_shape = _shape_infer(F.shape(x), self.num_features)
             y = self._global_sync(x, axes, re_shape)
-        elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        elif self.is_ge_backend or self.is_ascend:
             if self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
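
Note: dropping the `is_graph_mode` gate means `_BatchNorm` now selects `P.BatchNorm` on Ascend (or the GE backend) in PyNative mode as well, which lines up with the `BnSplit`/batch-norm fusion passes added to the run-op pipeline above.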