Merge pull request !6437 from baihuawei/clear-warningstags/v1.0.0
| @@ -21,6 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| const int kMaxLSTMLayer = 100; | |||
| void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| #ifdef PLATFORM_86 | |||
| _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); | |||
| @@ -42,7 +43,7 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | |||
| dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; | |||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | |||
| @@ -89,6 +90,9 @@ void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| if (num_layers_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "layers must be greater than zero!"; | |||
| } | |||
| if (num_layers_ > kMaxLSTMLayer) { | |||
| MS_LOG(EXCEPTION) << "layers must be lower than 100!"; | |||
| } | |||
| for (int i = 0; i < num_layers_; ++i) { | |||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | |||
| weight_h_size_ += gate_size * hidden_size_; | |||
| @@ -121,9 +125,8 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| if (has_bias_) { | |||
| bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); | |||
| } else { | |||
| auto ret = | |||
| memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); | |||
| if (ret != 0) { | |||
| if (memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, | |||
| prim_desc_.bias_desc().get_size())) { | |||
| MS_LOG(EXCEPTION) << "bias memset error"; | |||
| } | |||
| } | |||
| @@ -22,6 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| const int kMaxLSTMLayer = 100; | |||
| void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| using tag = dnnl::memory::format_tag; | |||
| @@ -38,7 +39,7 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; | |||
| weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; | |||
| bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; | |||
| dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; | |||
| dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; | |||
| dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; | |||
| dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); | |||
| @@ -107,6 +108,9 @@ void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| if (num_layers_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "layers must be greater than zero!"; | |||
| } | |||
| if (num_layers_ > kMaxLSTMLayer) { | |||
| MS_LOG(EXCEPTION) << "layers must be lower than 100!"; | |||
| } | |||
| for (int i = 0; i < num_layers_; ++i) { | |||
| weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); | |||
| weight_h_size_ += gate_size * hidden_size_; | |||
| @@ -25,7 +25,6 @@ | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh" | |||
| namespace mindspore { | |||
| @@ -33,12 +32,7 @@ namespace kernel { | |||
| template <typename T> | |||
| class MultinomialGpuKernel : public GpuKernel { | |||
| public: | |||
| MultinomialGpuKernel() | |||
| : input_size_0_(0), | |||
| output_size_(0), | |||
| distributions_(0), | |||
| workspace_size_(sizeof(curandState)), | |||
| replacement_(true) {} | |||
| MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {} | |||
| ~MultinomialGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -54,52 +48,13 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_; | |||
| int num_sample = SizeToInt(outputs[0]->size / sizeof(int)) / distributions_; | |||
| // check input | |||
| CheckPeram(input_addr, cum_sum_input, categories, stream_ptr); | |||
| if (replacement_) { | |||
| NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaStreamSynchronize failed."); | |||
| Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||
| IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| void CheckPeram(const T *input_addr, T *cum_sum_input, int categories, void *stream_ptr) { | |||
| T *flag = nullptr; | |||
| T *cflag = nullptr; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cflag), sizeof(T)), "cudaMalloc failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&flag, sizeof(T)), "cudaMallocHost failed."); | |||
| CalFloatStatus(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaStreamSynchronize failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); | |||
| if (*flag > 0) { | |||
| MS_LOG(EXCEPTION) << "Input is invalid (containing NaN, -inf or inf)"; | |||
| } | |||
| CheckNonNeg(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaStreamSynchronize failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); | |||
| if (*flag > 0) { | |||
| MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)"; | |||
| } | |||
| CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1, | |||
| IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaStreamSynchronize failed."); | |||
| CheckZero(IntToSize(distributions_), IntToSize(categories), cum_sum_input, cflag, | |||
| NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaStreamSynchronize failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); | |||
| if (*flag > 0) { | |||
| MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)"; | |||
| } | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed."); | |||
| Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), | |||
| IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| @@ -127,15 +82,10 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| output_size_ = sizeof(int); | |||
| workspace_size_ = sizeof(int); | |||
| replacement_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("replacement")); | |||
| if (replacement_) { | |||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||
| output_size_ *= output_shape[i]; | |||
| } | |||
| } | |||
| if (replacement_) { | |||
| workspace_size_ = output_size_; | |||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||
| output_size_ *= output_shape[i]; | |||
| } | |||
| workspace_size_ = output_size_; | |||
| seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); | |||
| InitSizeLists(); | |||
| return true; | |||
| @@ -155,7 +105,6 @@ class MultinomialGpuKernel : public GpuKernel { | |||
| size_t output_size_; | |||
| size_t distributions_; | |||
| size_t workspace_size_; | |||
| bool replacement_; | |||
| int seed_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -39,6 +39,9 @@ TENSOR_GETITEM = "tensor getitem" | |||
| SET_ITEM_BY_ONE_TENSOR = 0 | |||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | |||
| @constexpr | |||
| def raise_value_error(msg): | |||
| raise ValueError(msg) | |||
| @constexpr | |||
| def raise_index_error(msg): | |||
| @@ -255,11 +255,10 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): | |||
| shape = P.Shape() | |||
| reshape = P.Reshape() | |||
| if inputs.dim() != 1 and inputs.dim() != 2: | |||
| raise ValueError("inputs dim must be 1d or 2d") | |||
| const_utils.raise_value_error("inputs dim must be 1d or 2d") | |||
| if not replacement: | |||
| P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample) | |||
| if shape(inputs)[-1] < num_sample: | |||
| raise ValueError("num_sample must be less than shape(input)[-1] without replacement") | |||
| const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement") | |||
| n_dist = 1 | |||
| if len(shape(inputs)) > 1: | |||
| n_dist = shape(inputs)[-2] | |||
| @@ -269,4 +268,4 @@ def multinomial(inputs, num_sample, replacement=True, seed=0): | |||
| vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) | |||
| _, indices = P.TopK()(vals, num_sample) | |||
| return indices | |||
| return P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample) | |||
| return P.Multinomial(seed=seed)(inputs, num_sample) | |||
| @@ -433,8 +433,6 @@ class Multinomial(PrimitiveWithInfer): | |||
| Args: | |||
| seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers. | |||
| Must be non-negative. Default: 0. | |||
| replacement(bool): Whether to draw with replacement or not. | |||
| Inputs: | |||
| - **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 | |||
| dimensions. | |||
| @@ -445,16 +443,15 @@ class Multinomial(PrimitiveWithInfer): | |||
| Examples: | |||
| >>> input = Tensor([0., 9., 4., 0.], mstype.float32) | |||
| >>> multinomial = P.Multinomial(replacement=True, seed=10) | |||
| >>> multinomial = P.Multinomial(seed=10) | |||
| >>> output = multinomial(input, 2) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, replacement=True, seed=0): | |||
| def __init__(self, seed=0): | |||
| """init""" | |||
| validator.check_value_type("seed", seed, [int], self.name) | |||
| validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||
| validator.check_value_type("replacement", replacement, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | |||
| def __infer__(self, inputs, num_samples): | |||