
Change PyNative BatchNorm to behave the same as graph mode

tags/v0.6.0-beta
chujinjin committed 5 years ago
commit c880774942
6 changed files with 25 additions and 16 deletions
  1. +7  -1   mindspore/ccsrc/device/ascend/ascend_device_address.cc
  2. +1  -0   mindspore/ccsrc/device/ascend/ascend_device_address.h
  3. +2  -0   mindspore/ccsrc/device/device_address.h
  4. +4  -1   mindspore/ccsrc/device/kernel_runtime.cc
  5. +9  -11  mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
  6. +2  -3   mindspore/nn/layer/normalization.py
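
The user-visible effect of this commit is that a BatchNorm layer executed in PyNative mode on Ascend now goes through the same P.BatchNorm training kernel that graph mode already used. A minimal sketch of exercising that path (illustrative only; it assumes an available Ascend device):

    import numpy as np
    from mindspore import context, nn, Tensor

    # PyNative mode on Ascend now selects P.BatchNorm(is_training=True),
    # the same operator graph mode uses (see normalization.py below).
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    bn = nn.BatchNorm2d(4)
    bn.set_train()                    # training path -> self.bn_train
    x = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
    y = bn(x)
    print(y.shape)                    # (2, 4, 8, 8)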

+7 -1   mindspore/ccsrc/device/ascend/ascend_device_address.cc

@@ -174,7 +174,13 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::PaddingShapeTo4d(host_shape);
+    if (host_shape_.empty()) {
+      host_shape = trans::PaddingShapeTo4d(host_shape);
+    } else {
+      host_shape.clear();
+      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
+    }
+
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {


+1 -0   mindspore/ccsrc/device/ascend/ascend_device_address.h

@@ -47,6 +47,7 @@ class AscendDeviceAddress : public DeviceAddress {
bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
#endif

private:
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,


+2 -0   mindspore/ccsrc/device/device_address.h

@@ -63,6 +63,7 @@ class DeviceAddress {
   size_t GetSize() const { return size_; }
   std::string format() const { return format_; }
   TypeId type_id() const { return type_id_; }
+  void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
   virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@@ -77,6 +78,7 @@ class DeviceAddress {
   string format_{"DefaultFormat"};
   TypeId type_id_{kNumberTypeFloat16};
   bool from_mem_pool_{false};
+  std::vector<int> host_shape_{};
   friend class KernelRuntime;
   friend class MemoryManager;
   friend class mindspore::device::ascend::tasksink::TaskGenerator;


+4 -1   mindspore/ccsrc/device/kernel_runtime.cc

@@ -258,6 +258,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
     std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
     auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
     MS_EXCEPTION_IF_NULL(device_address);
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
     if (!ret) {
@@ -506,7 +507,9 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
     }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
+    auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
+    AnfAlgo::SetOutputAddr(device_address, i, node.get());
   }
 }
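
The host shape recorded above is what AscendDeviceAddress::SyncDeviceToHostAndConvertFormat (first hunk) now falls back to when an op output is copied back to the host, which in PyNative mode happens whenever Python reads a result. As an illustration, the raw P.BatchNorm primitive returns several 1-D tensors alongside the 4-D output, and all of them travel through this device-to-host path (a hedged sketch; it again assumes an Ascend device):

    import numpy as np
    from mindspore import context, Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    batch_norm = P.BatchNorm(is_training=False, epsilon=1e-5)
    x     = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
    scale = Tensor(np.ones(4, np.float32))
    bias  = Tensor(np.zeros(4, np.float32))
    mean  = Tensor(np.zeros(4, np.float32))
    var   = Tensor(np.ones(4, np.float32))

    # Reading the outputs back (asnumpy) triggers the device-to-host sync
    # that now consults the host shape set in RunOpAssignOutputMemory.
    y, batch_mean, batch_var, _, _ = batch_norm(x, scale, bias, mean, var)
    print(y.asnumpy().shape, batch_mean.asnumpy().shape)   # (2, 4, 8, 8) (4,)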



+9 -11  mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc

@@ -238,16 +238,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  if (context_ptr->execution_mode() == kPynativeMode) {
-    ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
-  } else {
-    ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
-  }
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
@@ -287,8 +282,11 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
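
With BatchNormGradSplit and the FusedBatchNorm fusion passes now registered for single-op (RunOp) execution as well, the backward pass of a BatchNorm layer can be exercised directly in PyNative mode. A hedged sketch, written against ops.GradOperation as exposed in recent MindSpore releases (the constructor signature differed slightly around v0.6) and again assuming an Ascend device:

    import numpy as np
    from mindspore import context, nn, ops, Tensor

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

    bn = nn.BatchNorm2d(4)
    bn.set_train()                    # training: forward BatchNorm, backward BatchNormGrad
    x = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))

    # Gradient of the layer output w.r.t. its input; on Ascend the backward
    # BatchNormGrad kernel is what the split passes above decompose.
    grad_fn = ops.GradOperation()(bn)
    dx = grad_fn(x)
    print(dx.shape)                   # (2, 4, 8, 8)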


+2 -3   mindspore/nn/layer/normalization.py

@@ -84,14 +84,13 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
-        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False

-        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        if self.is_ge_backend or self.is_ascend:
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
else:
@@ -153,7 +152,7 @@ class _BatchNorm(Cell):
         if self.is_ge_backend and self.is_global:
             axes, re_shape = _shape_infer(F.shape(x), self.num_features)
             y = self._global_sync(x, axes, re_shape)
-        elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        elif self.is_ge_backend or self.is_ascend:
             if self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
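
With the is_graph_mode gate removed, graph mode and PyNative mode now select the same BatchNorm kernel on Ascend, so the two modes should agree numerically. A small comparison sketch (illustrative; requires an Ascend device, and the tolerance is an arbitrary choice):

    import numpy as np
    from mindspore import context, nn, Tensor

    def run_bn(mode):
        context.set_context(mode=mode, device_target="Ascend")
        bn = nn.BatchNorm2d(4)                 # default init: gamma=1, beta=0
        bn.set_train(False)                    # inference: use moving statistics
        x = Tensor(np.linspace(-1, 1, 2 * 4 * 8 * 8, dtype=np.float32).reshape(2, 4, 8, 8))
        return bn(x).asnumpy()

    # Same (default) weights and the same input in both modes.
    print(np.allclose(run_bn(context.GRAPH_MODE), run_bn(context.PYNATIVE_MODE), atol=1e-5))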

