|
|
|
@@ -35,12 +35,14 @@ constexpr int LEFT = 0; |
|
|
|
constexpr int RIGHT = 1; |
|
|
|
constexpr size_t kMirrorPadGradInputsNum = 2; |
|
|
|
constexpr size_t kMirrorPadGradOutputsNum = 1; |
|
|
|
constexpr size_t kPadMaxSupportDim = 4; |
|
|
|
|
|
|
|
void extract_paddings(const int64_t *paddings_arg, int64_t padd_dim, int64_t *extracted_paddings) { |
|
|
|
template <typename T> |
|
|
|
void extract_paddings(const T *paddings_arg, int64_t padd_dim, int64_t *extracted_paddings) { |
|
|
|
const int64_t paddings_offset = MAX_PADDINGS - padd_dim; |
|
|
|
for (int64_t i = 0; i < padd_dim; i++) { |
|
|
|
extracted_paddings[(paddings_offset + i) * PADDING_SIZE] = paddings_arg[i * PADDING_SIZE]; |
|
|
|
extracted_paddings[(paddings_offset + i) * PADDING_SIZE + 1] = paddings_arg[i * PADDING_SIZE + 1]; |
|
|
|
extracted_paddings[(paddings_offset + i) * PADDING_SIZE] = int64_t(paddings_arg[i * PADDING_SIZE]); |
|
|
|
extracted_paddings[(paddings_offset + i) * PADDING_SIZE + 1] = int64_t(paddings_arg[i * PADDING_SIZE + 1]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -57,6 +59,7 @@ void MirrorPadGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
std::string mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "mode"); |
|
|
|
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); |
|
|
|
pad_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); |
|
|
|
if (mode == "REFLECT") { |
|
|
|
mode_ = 0; |
|
|
|
} else if (mode == "SYMMETRIC") { |
|
|
|
@@ -68,14 +71,8 @@ void MirrorPadGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
|
|
|
|
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
shape_size_ = input_shape.size(); |
|
|
|
if (shape_size_ == 4) { // shape adjustment from 2d/3d to 4d |
|
|
|
} else if (shape_size_ == 3) { |
|
|
|
(void)input_shape.insert(input_shape.begin(), 1); // batch padding |
|
|
|
shape_size_ = 4; |
|
|
|
} else if (shape_size_ == 2) { |
|
|
|
(void)input_shape.insert(input_shape.begin(), 2, 1); // channel padding |
|
|
|
shape_size_ = 4; |
|
|
|
} |
|
|
|
(void)input_shape.insert(input_shape.begin(), kPadMaxSupportDim - shape_size_, 1); |
|
|
|
shape_size_ = kPadMaxSupportDim; |
|
|
|
|
|
|
|
for (size_t i = 0; i < shape_size_; ++i) { |
|
|
|
tensor_size_ *= input_shape[i]; |
|
|
|
@@ -126,20 +123,28 @@ bool MirrorPadGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMirrorPadGradInputsNum, kernel_name_); |
|
|
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMirrorPadGradOutputsNum, kernel_name_); |
|
|
|
if (dtype_ == kNumberTypeFloat16) { |
|
|
|
LaunchKernel<float16>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32) { |
|
|
|
LaunchKernel<float>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat64) { |
|
|
|
LaunchKernel<double>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<int>(inputs, workspace, outputs); |
|
|
|
if (dtype_ == kNumberTypeFloat16 && pad_dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<float16, int32_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32 && pad_dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<float, int32_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat64 && pad_dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<double, int32_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeInt32 && pad_dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<int, int32_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat16 && pad_dtype_ == kNumberTypeInt64) { |
|
|
|
LaunchKernel<float16, int64_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32 && pad_dtype_ == kNumberTypeInt64) { |
|
|
|
LaunchKernel<float, int64_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat64 && pad_dtype_ == kNumberTypeInt64) { |
|
|
|
LaunchKernel<double, int64_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeInt32 && pad_dtype_ == kNumberTypeInt64) { |
|
|
|
LaunchKernel<int, int64_t>(inputs, workspace, outputs); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "For '" << kernel_name_ |
|
|
|
<< "', the dtype of 'input_x' should be float16, float32, float64, or int32, but got " |
|
|
|
<< TypeIdLabel(dtype_); |
|
|
|
<< "', the dtype of 'input_x' should be float16, float32, float64, or int32, and the dtype of " |
|
|
|
"'paddings' should be int32 or int64, but got " |
|
|
|
<< TypeIdLabel(dtype_) << " and " << TypeIdLabel(pad_dtype_); |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -161,11 +166,12 @@ void MirrorPadGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const T *interim_dy, const int64_t dx_height, |
|
|
|
const int64_t dx_width, const int64_t dy_height, |
|
|
|
const int64_t dy_width, const int64_t padd_dim, |
|
|
|
const int64_t *paddings_arg, int64_t mode, T *dx) const { |
|
|
|
template <typename T1, typename T2> |
|
|
|
void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const T1 *interim_dy, |
|
|
|
const int64_t dx_height, const int64_t dx_width, |
|
|
|
const int64_t dy_height, const int64_t dy_width, |
|
|
|
const int64_t padd_dim, const T2 *paddings_arg, int64_t mode, |
|
|
|
T1 *dx) const { |
|
|
|
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers |
|
|
|
for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) { |
|
|
|
paddings[i] = 0; // init all to 0 |
|
|
|
@@ -176,7 +182,11 @@ void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const |
|
|
|
int64_t ap2_x = paddings[WIDTH] + dx_width - 1; |
|
|
|
int64_t ap1_y = paddings[HEIGHT]; |
|
|
|
int64_t ap2_y = paddings[HEIGHT] + dx_height - 1; |
|
|
|
|
|
|
|
if (dx_width == 0 || dx_height == 0) { |
|
|
|
MS_LOG(EXCEPTION) |
|
|
|
<< "For MirrorPadGrad_Width_Height, the input argument 'dx_height' and 'dx_width' should not be 0, but got " |
|
|
|
<< "dy_height: " << dx_height << " dy_width: " << dx_width; |
|
|
|
} |
|
|
|
for (size_t pos = 0; pos < size; ++pos) { |
|
|
|
int64_t dx_block_num = (SizeToLong(pos) / dx_width) / dx_height; |
|
|
|
const int64_t grad_x = (SizeToLong(pos) % dx_width) + paddings[WIDTH]; |
|
|
|
@@ -225,12 +235,17 @@ void MirrorPadGradCPUKernel::MirrorPadGrad_Width_Height(const size_t size, const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void MirrorPadGradCPUKernel::MirrorPadGradBatchChannel(const size_t size, T *dy, T *interim_dy, |
|
|
|
template <typename T1, typename T2> |
|
|
|
void MirrorPadGradCPUKernel::MirrorPadGradBatchChannel(const size_t size, T1 *dy, T1 *interim_dy, |
|
|
|
const int64_t dx_batches, const int64_t dx_channels, |
|
|
|
const int64_t dy_height, const int64_t dy_width, |
|
|
|
const int64_t padd_dim, const int64_t *paddings_arg, |
|
|
|
const int64_t padd_dim, const T2 *paddings_arg, |
|
|
|
int64_t mode) const { |
|
|
|
if (dy_height == 0 || dy_width == 0 || dx_channels == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "For MirrorPadGradBatchChannel, the input argument 'dy_height', 'dy_width' and 'dx_channels' " |
|
|
|
"should not be 0, but got " |
|
|
|
<< "dy_height: " << dy_height << " dy_width: " << dy_width << " dx_channels: " << dx_channels; |
|
|
|
} |
|
|
|
int64_t paddings[MAX_PADDINGS * PADDING_SIZE]; // local and fixed size to keep in registers |
|
|
|
for (int i = 0; i < MAX_PADDINGS * PADDING_SIZE; i++) { |
|
|
|
paddings[i] = 0; // init all to 0 |
|
|
|
@@ -251,7 +266,7 @@ void MirrorPadGradCPUKernel::MirrorPadGradBatchChannel(const size_t size, T *dy, |
|
|
|
const int64_t interim_y = (SizeToLong(pos) / dy_width) % dy_height; |
|
|
|
const int64_t interim_channel = block_num % dx_channels; |
|
|
|
const int64_t interim_batch = block_num / dx_channels; |
|
|
|
interim_dy[pos] = T(0); // init |
|
|
|
interim_dy[pos] = T1(0); // init |
|
|
|
// map cur interim channel and batch to equivalent in padded dy array |
|
|
|
const int64_t equiv_dy_channel = interim_channel + paddings[CHANNEL]; |
|
|
|
const int64_t equiv_dy_batch = interim_batch + paddings[BATCH]; |
|
|
|
@@ -274,21 +289,21 @@ void MirrorPadGradCPUKernel::MirrorPadGradBatchChannel(const size_t size, T *dy, |
|
|
|
} |
|
|
|
equiv_block_num = ((target_batch * dy_channels) + target_channel); |
|
|
|
// Copy data and set value at input to 0 to avoid duplicates in reflect mode |
|
|
|
interim_dy[pos] = T(interim_dy[pos] + dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x]); |
|
|
|
dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x] = T(0); |
|
|
|
interim_dy[pos] = T1(interim_dy[pos] + dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x]); |
|
|
|
dy[(equiv_block_num * dy_height + interim_y) * dy_width + interim_x] = T1(0); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
template <typename T1, typename T2> |
|
|
|
void MirrorPadGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<AddressPtr> &workspace, |
|
|
|
const std::vector<AddressPtr> &outputs) const { |
|
|
|
auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->addr); |
|
|
|
auto *paddings = reinterpret_cast<int64_t *>(inputs[1]->addr); |
|
|
|
auto *interim = reinterpret_cast<T *>(workspace[0]->addr); |
|
|
|
auto *outputs_addr = reinterpret_cast<T *>(outputs[0]->addr); |
|
|
|
auto *inputs_addr = reinterpret_cast<T1 *>(inputs[0]->addr); |
|
|
|
auto *paddings = reinterpret_cast<T2 *>(inputs[1]->addr); |
|
|
|
auto *interim = reinterpret_cast<T1 *>(workspace[0]->addr); |
|
|
|
auto *outputs_addr = reinterpret_cast<T1 *>(outputs[0]->addr); |
|
|
|
|
|
|
|
MirrorPadGradBatchChannel(workspace_size_, inputs_addr, interim, output_shape_[0], output_shape_[1], input_shape_[2], |
|
|
|
input_shape_[3], num_paddings_, paddings, mode_); |
|
|
|
|