From: @huaweib Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54 tags/v1.2.0-rc1
| @@ -24,6 +24,10 @@ void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | ||||
| std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | ||||
| if (src1_shape.size() == 0 && src0_shape.size() == 0) { | |||||
| src0_shape.insert(src0_shape.begin(), 1); | |||||
| src1_shape.insert(src1_shape.begin(), 1); | |||||
| } | |||||
| if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { | if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { | ||||
| MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs " | MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs " | ||||
| << src1_shape.size(); | << src1_shape.size(); | ||||
| @@ -130,7 +130,11 @@ dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::d | |||||
| dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) { | dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) { | ||||
| dnnl::memory::dims dims; | dnnl::memory::dims dims; | ||||
| dims.insert(dims.end(), shape.begin(), shape.end()); | |||||
| if (shape.size() == 0) { | |||||
| dims.insert(dims.end(), 1); | |||||
| } else { | |||||
| dims.insert(dims.end(), shape.begin(), shape.end()); | |||||
| } | |||||
| dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); | dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); | ||||
| dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); | dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); | ||||
| return mem_desc; | return mem_desc; | ||||
| @@ -151,7 +151,7 @@ class CtcLossGpuKernel : public GpuKernel { | |||||
| void LaunchSecondHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | void LaunchSecondHalf(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | const std::vector<AddressPtr> &outputs, void *stream_ptr) { | ||||
| cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr); | cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr); | ||||
| int SOffSet = 2 * max_labels_length_host + 1; | |||||
| const int SOffSet = 2 * max_labels_length_host + 1; | |||||
| int log_prob_size = batch * SOffSet * max_time; | int log_prob_size = batch * SOffSet * max_time; | ||||
| if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) { | if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) { | ||||