|
|
|
@@ -45,36 +45,36 @@ class OneHotGpuFwdKernel : public GpuKernel { |
|
|
|
return true; |
|
|
|
} |
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
int axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis")); |
|
|
|
auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
int input_size = SizeToInt(input.size()); |
|
|
|
const int default_axis = -1; |
|
|
|
int64_t axis = GetAttr<int64_t>(kernel_node, "axis"); |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
int64_t input_dims = static_cast<int64_t>(input_shape.size()); |
|
|
|
if (axis >= input_dims) { |
|
|
|
MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input_shape.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
const int64_t default_axis = -1; |
|
|
|
|
|
|
|
// Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims). |
|
|
|
for (int i = 0; i < input_size; i++) { |
|
|
|
auto dim_size = input[IntToSize(i)]; |
|
|
|
if (axis == default_axis || i < axis) { |
|
|
|
for (size_t i = 0; i < input_shape.size(); i++) { |
|
|
|
auto dim_size = input_shape[i]; |
|
|
|
if (axis == default_axis || i < IntToSize(axis)) { |
|
|
|
left_dim_size_ *= dim_size; |
|
|
|
} |
|
|
|
if (axis != default_axis && i >= axis) { |
|
|
|
if (axis != default_axis && i >= IntToSize(axis)) { |
|
|
|
right_dim_size_ *= dim_size; |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto size : input) { |
|
|
|
for (auto size : input_shape) { |
|
|
|
input_size_ *= size; |
|
|
|
} |
|
|
|
for (auto size : output) { |
|
|
|
for (auto size : output_shape) { |
|
|
|
output_size_ *= size; |
|
|
|
} |
|
|
|
if (axis >= input_size) { |
|
|
|
MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (axis == default_axis) { |
|
|
|
depth_ = output[output.size() - 1]; |
|
|
|
depth_ = output_shape[output_shape.size() - 1]; |
|
|
|
} else { |
|
|
|
depth_ = output[IntToSize(axis)]; |
|
|
|
depth_ = output_shape[IntToSize(axis)]; |
|
|
|
} |
|
|
|
InitSizeLists(); |
|
|
|
return true; |
|
|
|
|