Merge pull request !2405 from chenzhongming/master (tags/v0.6.0-beta)
@@ -248,7 +248,7 @@ checkopts()
   done
 }
 checkopts "$@"
-echo "---------------- mindspore: build start ----------------"
+echo "---------------- MindSpore: build start ----------------"
 mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
 git submodule update --init graphengine
 if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
@@ -36,7 +36,7 @@ class Monitor {
   ~Monitor() = default;
   // Functor for Perf Monitor main loop.
-  // This function will be the entry point of Mindspore::Dataset::Task
+  // This function will be the entry point of mindspore::Dataset::Task
   Status operator()();
   int64_t GetSamplingInterval() { return sampling_interval_; }
@@ -29,7 +29,7 @@
 // brief mindspore namespace.
 //
-// mindspore namespace is the top level namespace of Mindsporeession project.
+// mindspore namespace is the top level namespace of MindSpore project.
 // Other namespace should be a sub namespace of mindspore namespace in the ME project.
 namespace mindspore {
@@ -91,7 +91,7 @@ using mindspore::device::DeviceAddress;
 using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
 // brief mindspore namespace.
 //
-// mindspore namespace is the top level namespace of Mindsporeession project.
+// mindspore namespace is the top level namespace of MindSpore project.
 // Other namespace should be a sub namespace of mindspore namespace in the ME project.
 namespace mindspore {
 // brief mindspore::tensor namespace
@@ -19,7 +19,7 @@
 #include <thrust/execution_policy.h>
 #include <thrust/reduce.h>
 #include <thrust/pair.h>
-#include "fake_quant_per_channel_impl.cuh"
+#include "fake_quant_perchannel_impl.cuh"
 #include "device/gpu/cuda_common.h"
 /**
@@ -113,44 +113,6 @@ void CalFakeQuantizePerChannel(const float *input, float *output, const int tota
     input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
 }
-/**
- * UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannel With EMA.
- * @param input_min
- * @param input_max
- * @param min
- * @param max
- * @return
- */
-__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
-                                            int per_channel_nums, bool ema, float ema_decay) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
-    thrust::pair<float *, float *> sum =
-      thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
-    if (ema) {
-      input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
-      input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
-    } else {
-      input_min[i] = sum.first[0];
-      input_max[i] = sum.second[0];
-    }
-    input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
-    input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
-  }
-}
-__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
-                                                   const float decay) {
-  *input_min = decay * (min) + (1 - decay) * (*input_min);
-  *input_max = decay * (max) + (1 - decay) * (*input_max);
-}
-void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
-                         const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
-  int per_channel_num = total_size / channel_size;
-  UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
-    input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
-}
 __global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
                                            const int total_size, const int channel_size, const float *nudge_min,
                                            const float *nudge_max) {
@@ -18,7 +18,7 @@
 #include <thrust/device_vector.h>
 #include <thrust/pair.h>
 #include "device/gpu/cuda_common.h"
-#include "fake_quant_impl.cuh"
+#include "fake_quant_perlayer_impl.cuh"
 __global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
                              const float *nudge_max, const float *scale) {
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <thrust/extrema.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/reduce.h>
+#include <thrust/pair.h>
+#include "minmax_update_impl.cuh"
+#include "device/gpu/cuda_common.h"
+__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
+                                                 float *output_max, const float min, const float max,
+                                                 const float decay, const float symmetric) {
+  output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
+  output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
+  output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
+  output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
+  if (symmetric) {
+    output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
+    output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
+  }
+  return;
+}
+__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
+                                          const float symmetric) {
+  output_min[0] = min > 0 ? 0 : min;
+  output_max[0] = max < 0 ? 0 : max;
+  if (symmetric) {
+    output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
+    output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
+  }
+  return;
+}
+__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
+                                            float *output_max, int channels, int per_channel_nums, bool ema,
+                                            float ema_decay, bool symmetric) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
+    thrust::pair<float *, float *> sum =
+      thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
+    if (ema) {
+      output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
+      output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
+    } else {
+      output_min[i] = sum.first[0];
+      output_max[i] = sum.second[0];
+    }
+    output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
+    output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
+    if (symmetric) {
+      output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
+      output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
+    }
+  }
+  return;
+}
+void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
+                         const bool symmetric, cudaStream_t cuda_stream) {
+  int per_channel_num = total_num / channel_num;
+  UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
+    input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
+  return;
+}
+void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                       const int total_num, const float ema_decay, const bool ema, const bool symmetric,
+                       cudaStream_t cuda_stream) {
+  float minel = 0.f;
+  float maxel = 0.f;
+  auto policy = thrust::cuda::par.on(cuda_stream);
+  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
+  tuple =
+    thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num);
+  minel = tuple.first[0];
+  maxel = tuple.second[0];
+  if (ema) {
+    UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
+                                                               maxel, ema_decay, symmetric);
+  } else {
+    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
+  }
+  return;
+}
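For orientation, the kernels added above track activation ranges with an exponential moving average. The following host-side sketch is illustrative only and not part of this patch (the function name is invented here); it restates the EMA blend, the zero-containment clamp, and the symmetric widening in plain C++, while the real kernels additionally handle per-channel indexing and apply the clamp against the previous input_min/input_max buffers.

#include <algorithm>
#include <cmath>

// Illustrative sketch (assumed helper, not in the patch): EMA range tracking with the
// zero-containment and symmetric adjustments used by the minmax update kernels above.
void EmaRangeUpdateSketch(float batch_min, float batch_max, float *running_min, float *running_max,
                          float decay, bool symmetric) {
  *running_min = decay * batch_min + (1 - decay) * (*running_min);  // EMA blend of per-batch min
  *running_max = decay * batch_max + (1 - decay) * (*running_max);  // EMA blend of per-batch max
  *running_min = std::min(*running_min, 0.0f);                      // keep zero inside the range
  *running_max = std::max(*running_max, 0.0f);
  if (symmetric) {
    float bound = std::max(std::fabs(*running_min), *running_max);  // widen to a range symmetric around zero
    *running_min = -bound;
    *running_max = bound;
  }
}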
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
+#include "device/gpu/cuda_common.h"
+void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
+                         const bool symmetric, cudaStream_t cuda_stream);
+void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                       const int size, const float ema_decay, const bool ema, const bool symmetric,
+                       cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
-#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
+#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
+#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
 #include <thrust/extrema.h>
 #include <thrust/pair.h>
 #include <thrust/device_vector.h>
@@ -25,21 +25,15 @@ namespace mindspore {
 namespace kernel {
 FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
     : input_size_(0),
-      min_size_(0),
-      max_size_(0),
-      output_size_(0),
-      workspace_size_(0),
+      num_channels_(0),
       num_bits_(0),
-      quant_min_(0),
-      quant_max_(0),
-      quant_delay_(0),
-      ema_(false),
-      ema_decay_(0),
-      global_step_(0),
       training_(false),
-      channel_out_(0),
+      symmetric_(false),
       narrow_range_(false),
-      symmetric_(false) {}
+      quant_delay_(0),
+      quant_min_(0),
+      quant_max_(0),
+      global_step_(0) {}
 const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
@@ -60,90 +54,56 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
     return false;
   }
-  // get attribute
   num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
-  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
-  ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
+  training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
   if (num_bits_ <= 2 || num_bits_ >= 16) {
     MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
     return false;
   }
-  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
   if (quant_delay_ < 0) {
     MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
     return false;
   }
-  training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
-  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  if (symmetric_) {
-    quant_min_ = 0 - (1 << (num_bits_ - 1));
-    quant_max_ = (1 << (num_bits_ - 1)) - 1;
-  } else {
-    quant_min_ = 0;
-    quant_max_ = (1 << num_bits_) - 1;
-  }
-  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
   if (narrow_range_) {
     quant_min_++;
   }
   // shape info for gpu
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  channel_out_ = SizeToInt(input_shape[0]);
-  min_size_ = sizeof(float) * channel_out_;
-  max_size_ = sizeof(float) * channel_out_;
+  num_channels_ = SizeToInt(input_shape[0]);
   input_size_ = sizeof(float);
   for (size_t i = 0; i < input_shape.size(); i++) {
     input_size_ *= input_shape[i];
   }
-  output_size_ = input_size_;
   InitSizeLists();
   return true;
 }
 void FakeQuantPerChannelGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);                       // input in tensor
-  input_size_list_.push_back(min_size_);                         // min one scalar
-  input_size_list_.push_back(max_size_);                         // max on scalar
-  output_size_list_.push_back(output_size_);                     // output in tensor
-  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // scale in channel
-  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // min in channel
-  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // max in channel
-}
-void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
-                                                              float *input_max, float *d_nudge_min, float *d_nudge_max,
-                                                              float *d_scale, void *stream_ptr) {
-  // calculate the input min and max according by the parameter ema and ema_decay.
-  CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
-  // control flow for quant_delay
-  if (global_step_ >= quant_delay_) {
-    // real launch
-    CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
-                              d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
-  } else {
-    CHECK_CUDA_RET_WITH_ERROR(
-      cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-      "Copy gpu memory failed.");
-  }
-  global_step_++;
+  input_size_list_.push_back(input_size_);                        // input in tensor
+  input_size_list_.push_back(sizeof(float) * num_channels_);      // min one scalar
+  input_size_list_.push_back(sizeof(float) * num_channels_);      // max on scalar
+  output_size_list_.push_back(input_size_);                       // output in tensor
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // scale in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // min in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // max in channel
 }
-void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
-                                                           float *input_max, float *d_nudge_min, float *d_nudge_max,
-                                                           float *d_scale, void *stream_ptr) {
-  // real launch
-  CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
+void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
+                                                   float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
+  CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
-  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
+  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
                             symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
 }
@@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
   float *input = GetDeviceAddress<float>(inputs, 0);
   float *input_min = GetDeviceAddress<float>(inputs, 1);
   float *input_max = GetDeviceAddress<float>(inputs, 2);
-  float *d_scale = GetDeviceAddress<float>(workspace, 0);
-  float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
-  float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
   if (input == nullptr) {
     MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
@@ -167,9 +127,16 @@
   }
   if (training_) {
-    CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
+    if (global_step_ >= quant_delay_) {
+      CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
+    } else {
+      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
+                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                "Copy gpu memory failed.");
+    }
+    global_step_++;
   } else {
-    CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
+    CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
   }
   return true;
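As a quick worked example of the simplified range logic in the new Init() above (the old symmetric branch is removed, so the range is always derived the same way), here is the arithmetic for a typical 8-bit configuration; the snippet is illustrative only and not part of the patch.

// Illustrative only: quant range for an 8-bit fake-quant op after this change.
int num_bits = 8;
bool narrow_range = true;
float quant_min = 0;
float quant_max = (1 << num_bits) - 1;  // 255 for 8 bits
if (narrow_range) {
  quant_min++;  // narrow_range shrinks the range to [1, 255]; otherwise it stays [0, 255]
}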
@@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
   void InitSizeLists() override;
  private:
-  void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
-                                  float *d_nudge_max, float *d_scale, void *stream_ptr);
-  void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
-                               float *d_nudge_max, float *d_scale, void *stream_ptr);
+  void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
+                       float *nudge_max, float *scale, void *stream_ptr);
   size_t input_size_;
-  size_t min_size_;
-  size_t max_size_;
-  size_t output_size_;
-  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
+  int num_channels_;
   int num_bits_;
+  bool training_;
+  bool symmetric_;
+  bool narrow_range_;
+  int quant_delay_;
   float quant_min_;
   float quant_max_;
-  int quant_delay_;
-  bool ema_;
-  float ema_decay_;
   int global_step_;
-  bool training_;
-  int channel_out_;
-  bool narrow_range_;
-  bool symmetric_;
 };
 } // namespace kernel
 } // namespace mindspore
| @@ -14,21 +14,17 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h" | |||||
| #include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh" | |||||
| #include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h" | |||||
| #include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() | FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() | ||||
| : input_size_(0), | : input_size_(0), | ||||
| min_size_(0), | |||||
| max_size_(0), | |||||
| output_size_(0), | |||||
| workspace_size_(0), | |||||
| num_bits_(0), | num_bits_(0), | ||||
| quant_min_(0), | quant_min_(0), | ||||
| quant_max_(0), | quant_max_(0), | ||||
| channel_out_(0), | |||||
| num_channels_(0), | |||||
| quant_delay_(0), | quant_delay_(0), | ||||
| global_step_(0), | global_step_(0), | ||||
| narrow_range_(false), | narrow_range_(false), | ||||
| @@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | ||||
| if (symmetric_) { | |||||
| quant_min_ = 0 - (1 << (num_bits_ - 1)); | |||||
| quant_max_ = (1 << (num_bits_ - 1)) - 1; | |||||
| } else { | |||||
| quant_min_ = 0; | |||||
| quant_max_ = (1 << num_bits_) - 1; | |||||
| } | |||||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | ||||
| // quant min and max value | |||||
| quant_min_ = 0; | |||||
| quant_max_ = (1 << num_bits_) - 1; | |||||
| if (narrow_range_) { | if (narrow_range_) { | ||||
| quant_min_++; | quant_min_++; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| channel_out_ = SizeToInt(input_shape[0]); | |||||
| min_size_ = sizeof(float) * channel_out_; | |||||
| max_size_ = sizeof(float) * channel_out_; | |||||
| num_channels_ = SizeToInt(input_shape[0]); | |||||
| input_size_ = sizeof(float); | input_size_ = sizeof(float); | ||||
| for (size_t i = 0; i < input_shape.size(); i++) { | for (size_t i = 0; i < input_shape.size(); i++) { | ||||
| input_size_ *= input_shape[i]; | input_size_ *= input_shape[i]; | ||||
| } | } | ||||
| output_size_ = input_size_; | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { | void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { | ||||
| input_size_list_.push_back(input_size_); // gradient | |||||
| input_size_list_.push_back(input_size_); // input | |||||
| input_size_list_.push_back(min_size_); // min | |||||
| input_size_list_.push_back(max_size_); // max | |||||
| output_size_list_.push_back(output_size_); | |||||
| workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel | |||||
| workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel | |||||
| workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel | |||||
| input_size_list_.push_back(input_size_); // gradient | |||||
| input_size_list_.push_back(input_size_); // input | |||||
| input_size_list_.push_back(sizeof(float) * num_channels_); // min | |||||
| input_size_list_.push_back(sizeof(float) * num_channels_); // max | |||||
| output_size_list_.push_back(input_size_); // output | |||||
| workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel | |||||
| workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel | |||||
| workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel | |||||
| } | } | ||||
| bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, | bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, | ||||
| @@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp | |||||
| float *input = GetDeviceAddress<float>(inputs, 1); | float *input = GetDeviceAddress<float>(inputs, 1); | ||||
| float *input_min = GetDeviceAddress<float>(inputs, 2); | float *input_min = GetDeviceAddress<float>(inputs, 2); | ||||
| float *input_max = GetDeviceAddress<float>(inputs, 3); | float *input_max = GetDeviceAddress<float>(inputs, 3); | ||||
| float *d_scale = GetDeviceAddress<float>(workspace, 0); | |||||
| float *d_nudge_min = GetDeviceAddress<float>(workspace, 1); | |||||
| float *d_nudge_max = GetDeviceAddress<float>(workspace, 2); | |||||
| float *scale = GetDeviceAddress<float>(workspace, 0); | |||||
| float *nudge_min = GetDeviceAddress<float>(workspace, 1); | |||||
| float *nudge_max = GetDeviceAddress<float>(workspace, 2); | |||||
| if (gradient == nullptr) { | if (gradient == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; | MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; | ||||
| @@ -130,9 +118,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp | |||||
| int total_size = input_size_ / sizeof(float); | int total_size = input_size_ / sizeof(float); | ||||
| if (global_step_ >= quant_delay_) { | if (global_step_ >= quant_delay_) { | ||||
| CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, | |||||
| CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max, | |||||
| CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| } else { | } else { | ||||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, | CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, | ||||
@@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
  private:
   size_t input_size_;
-  size_t min_size_;
-  size_t max_size_;
-  size_t output_size_;
-  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
   int num_bits_;
   float quant_min_;
   float quant_max_;
-  int channel_out_;
+  int num_channels_;
   int quant_delay_;
   int global_step_;
   bool narrow_range_;
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
-#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
+#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
+#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
 #include <thrust/extrema.h>
 #include <thrust/pair.h>
 #include <thrust/device_vector.h>
@@ -23,31 +23,25 @@
 namespace mindspore {
 namespace kernel {
-FakeQuantGpuKernel::FakeQuantGpuKernel()
+FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
     : input_size_(0),
-      min_size_(0),
-      max_size_(0),
-      output_size_(0),
-      workspace_size_(0),
-      num_bits_(0),
       quant_min_(0),
       quant_max_(0),
-      quant_num_(0),
-      quant_delay_(0),
-      ema_(false),
-      ema_decay_(0),
+      quant_num_(1),
      global_step_(0),
+      num_bits_(0),
+      quant_delay_(0),
       training_(false),
       narrow_range_(false),
       symmetric_(false) {}
-const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
-const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
-const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
-bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
+bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   if (input_num != 3) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
@@ -59,95 +53,73 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
   }
   num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
-  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
-  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
-  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
   training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
   if (num_bits_ <= 2 || num_bits_ >= 16) {
     MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
   }
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
   if (quant_delay_ < 0) {
     MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
   }
-  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  if (symmetric_) {
-    quant_min_ = 0 - (1 << (num_bits_ - 1));
-    quant_max_ = (1 << (num_bits_ - 1)) - 1;
-  } else {
-    quant_min_ = 0;
-    quant_max_ = (1 << num_bits_) - 1;
-  }
-  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
   if (narrow_range_) {
     quant_min_++;
   }
-  if (quant_num_ == 0) {
-    quant_num_ = 1;
-  }
-  // init size
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   for (size_t i = 0; i < input_shape.size(); ++i) {
     quant_num_ *= SizeToInt(input_shape[i]);
   }
   input_size_ = sizeof(float);
-  min_size_ = sizeof(float);
-  max_size_ = sizeof(float);
   for (size_t i = 0; i < input_shape.size(); i++) {
     input_size_ *= input_shape[i];
   }
-  output_size_ = input_size_;
   InitSizeLists();
   return true;
 }
-void FakeQuantGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);  // input
-  input_size_list_.push_back(min_size_);    // min
-  input_size_list_.push_back(max_size_);    // max
-  output_size_list_.push_back(output_size_);
-  workspace_size_list_.push_back(workspace_size_);
+void FakeQuantPerLayerGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);        // x
+  input_size_list_.push_back(sizeof(float));      // min
+  input_size_list_.push_back(sizeof(float));      // max
+  output_size_list_.push_back(input_size_);       // y
+  workspace_size_list_.push_back(sizeof(float));  // scale
+  workspace_size_list_.push_back(sizeof(float));  // nudge_min
+  workspace_size_list_.push_back(sizeof(float));  // nudge_max
 }
-bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-                                const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                        const std::vector<AddressPtr> &outputs, void *stream_ptr) {
   float *output = GetDeviceAddress<float>(outputs, 0);
   float *input = GetDeviceAddress<float>(inputs, 0);
   float *input_min = GetDeviceAddress<float>(inputs, 1);
   float *input_max = GetDeviceAddress<float>(inputs, 2);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
   if (input == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
-  }
-  if (input_min == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
  }
-  if (input_max == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
   }
-  // Allocate space for device copies
-  int size = sizeof(float);
-  float *d_scale = nullptr;
-  float *d_nudge_min = nullptr;
-  float *d_nudge_max = nullptr;
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
   if (training_) {
-    // calculate the input min and max according by the parameter ema and ema_decay.
-    CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
     // control flow for quant_delay
     if (global_step_ >= quant_delay_) {
       // real launch
-      CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
+      CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
                reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
+      CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
                       reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
       CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
@@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
     global_step_++;
   } else {
     // real launch
-    CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
+    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
              reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
+    CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
   }
-  // Cleanup
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
 } // namespace kernel
 } // namespace mindspore
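The per-layer Launch() above now takes scale/nudge buffers from the kernel workspace instead of calling cudaMalloc/cudaFree on every launch, and before quant_delay training steps have elapsed the op is a plain copy. The following condensed sketch is illustrative only (the wrapper function is invented here; error handling and GetDeviceAddress calls are omitted, and CalNudge/CalFakeQuantize are the helpers declared in fake_quant_perlayer_impl.cuh).

// Illustrative sketch of the refactored per-layer launch flow (not a verbatim copy of the patch).
void LaunchFlowSketch(bool training, int *global_step, int quant_delay, float *input, float *output,
                      float *input_min, float *input_max, float *nudge_min, float *nudge_max, float *scale,
                      int quant_num, float quant_min, float quant_max, bool symmetric, size_t input_size,
                      cudaStream_t stream) {
  if (!training || *global_step >= quant_delay) {
    // nudge the range, then fake-quantize in place of the real values
    CalNudge(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, stream);
    CalFakeQuantize(input, output, quant_num, nudge_min, nudge_max, scale, symmetric, stream);
  } else {
    // identity copy while global_step < quant_delay during training
    cudaMemcpyAsync(output, input, input_size, cudaMemcpyDeviceToDevice, stream);
  }
  if (training) {
    (*global_step)++;
  }
}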
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
 #include <vector>
 #include "kernel/gpu/gpu_kernel.h"
@@ -23,10 +23,10 @@
 namespace mindspore {
 namespace kernel {
-class FakeQuantGpuKernel : public GpuKernel {
+class FakeQuantPerLayerGpuKernel : public GpuKernel {
  public:
-  FakeQuantGpuKernel();
-  ~FakeQuantGpuKernel() = default;
+  FakeQuantPerLayerGpuKernel();
+  ~FakeQuantPerLayerGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override;
   const std::vector<size_t> &GetOutputSizeList() const override;
@@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
  private:
   size_t input_size_;
-  size_t min_size_;
-  size_t max_size_;
-  size_t output_size_;
-  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  int num_bits_;
   float quant_min_;
   float quant_max_;
   int quant_num_;
-  int quant_delay_;
-  bool ema_;
-  float ema_decay_;
   int global_step_;
+  int num_bits_;
+  int quant_delay_;
   bool training_;
   bool narrow_range_;
   bool symmetric_;
@@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
 } // namespace kernel
 } // namespace mindspore
-#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
@@ -14,33 +14,30 @@
  * limitations under the License.
  */
-#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
-#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
+#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
+#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
 namespace mindspore {
 namespace kernel {
-FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
+FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
     : input_size_(0),
-      min_size_(0),
-      max_size_(0),
-      output_size_(0),
       workspace_size_(0),
       num_bits_(0),
       quant_min_(0),
       quant_max_(0),
-      quant_size_(0),
+      quant_num_(1),
       quant_delay_(0),
       global_step_(0),
       narrow_range_(false),
       symmetric_(false) {}
-const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
-const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
-const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
-bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
+bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   if (input_num != 4) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
@@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
   }
   symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  if (symmetric_) {
-    quant_min_ = 0 - (1 << (num_bits_ - 1));
-    quant_max_ = (1 << (num_bits_ - 1)) - 1;
-  } else {
-    quant_min_ = 0;
-    quant_max_ = (1 << num_bits_) - 1;
-  }
   narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
   if (narrow_range_) {
     quant_min_++;
   }
-  if (quant_size_ == 0) {
-    quant_size_ = 1;
-  }
-  // init size
   auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   for (size_t i = 0; i < input_shape.size(); ++i) {
-    quant_size_ *= SizeToInt(input_shape[i]);
+    quant_num_ *= SizeToInt(input_shape[i]);
   }
   input_size_ = sizeof(float);
-  min_size_ = sizeof(float);
-  max_size_ = sizeof(float);
   for (size_t i = 0; i < input_shape.size(); i++) {
     input_size_ *= input_shape[i];
   }
-  output_size_ = input_size_;
   InitSizeLists();
   return true;
 }
-void FakeQuantGradGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);  // gradient
-  input_size_list_.push_back(input_size_);  // input
-  input_size_list_.push_back(min_size_);    // min
-  input_size_list_.push_back(max_size_);    // max
-  output_size_list_.push_back(output_size_);
+void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);        // gradient
+  input_size_list_.push_back(input_size_);        // input
+  input_size_list_.push_back(sizeof(float));      // min
+  input_size_list_.push_back(sizeof(float));      // max
+  output_size_list_.push_back(input_size_);       // output
+  workspace_size_list_.push_back(sizeof(float));  // scale
+  workspace_size_list_.push_back(sizeof(float));  // nudge_min
+  workspace_size_list_.push_back(sizeof(float));  // nudge_max
 }
-bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-                                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                            const std::vector<AddressPtr> &workspace,
+                                            const std::vector<AddressPtr> &outputs, void *stream_ptr) {
   float *output = GetDeviceAddress<float>(outputs, 0);
   float *gradient = GetDeviceAddress<float>(inputs, 0);
   float *input = GetDeviceAddress<float>(inputs, 1);
   float *input_min = GetDeviceAddress<float>(inputs, 2);
   float *input_max = GetDeviceAddress<float>(inputs, 3);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
   if (gradient == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null";
   }
   if (input == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null.";
   }
-  if (input_min == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
-  }
-  if (input_max == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null.";
   }
   if (global_step_ >= quant_delay_) {
-    float *d_scale = nullptr;
-    float *d_nudge_min = nullptr;
-    float *d_nudge_max = nullptr;
-    int size = sizeof(float);
-    // Allocate space for device copies
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
-    CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
+    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
+    CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
                         reinterpret_cast<cudaStream_t>(stream_ptr));
-    // Cleanup
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel)
 } // namespace kernel
 } // namespace mindspore
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ | |||||
| #include <vector> | #include <vector> | ||||
| #include "kernel/gpu/gpu_kernel.h" | #include "kernel/gpu/gpu_kernel.h" | ||||
| @@ -23,10 +23,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| class FakeQuantGradGpuKernel : public GpuKernel { | |||||
| class FakeQuantPerLayerGradGpuKernel : public GpuKernel { | |||||
| public: | public: | ||||
| FakeQuantGradGpuKernel(); | |||||
| ~FakeQuantGradGpuKernel() = default; | |||||
| FakeQuantPerLayerGradGpuKernel(); | |||||
| ~FakeQuantPerLayerGradGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override; | const std::vector<size_t> &GetInputSizeList() const override; | ||||
| const std::vector<size_t> &GetOutputSizeList() const override; | const std::vector<size_t> &GetOutputSizeList() const override; | ||||
| @@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel { | |||||
| private: | private: | ||||
| size_t input_size_; | size_t input_size_; | ||||
| size_t min_size_; | |||||
| size_t max_size_; | |||||
| size_t output_size_; | |||||
| size_t workspace_size_; | size_t workspace_size_; | ||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| @@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel { | |||||
| int num_bits_; | int num_bits_; | ||||
| float quant_min_; | float quant_min_; | ||||
| float quant_max_; | float quant_max_; | ||||
| int quant_size_; | |||||
| int quant_num_; | |||||
| int quant_delay_; | int quant_delay_; | ||||
| int global_step_; | int global_step_; | ||||
| bool narrow_range_; | bool narrow_range_; | ||||
| @@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel { | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ | |||||
| @@ -0,0 +1,119 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h" | |||||
| #include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" | |||||
| #include <thrust/extrema.h> | |||||
| #include <thrust/pair.h> | |||||
| #include <thrust/device_vector.h> | |||||
| #include <cuda_runtime_api.h> | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() | |||||
| : input_size_(0), | |||||
| num_bits_(0), | |||||
| quant_min_(0), | |||||
| quant_max_(0), | |||||
| quant_num_(1), | |||||
| ema_(false), | |||||
| ema_decay_(0), | |||||
| num_channels_(0), | |||||
| narrow_range_(false), | |||||
| symmetric_(false) {} | |||||
| const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||||
| const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||||
| const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const { | |||||
| return workspace_size_list_; | |||||
| } | |||||
| bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 3) { | |||||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 2) { | |||||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; | |||||
| } | |||||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||||
| ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||||
| } | |||||
| // quant min and max | |||||
| quant_min_ = 0; | |||||
| quant_max_ = (1 << num_bits_) - 1; | |||||
| if (narrow_range_) { | |||||
| quant_min_++; | |||||
| } | |||||
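| // e.g. with num_bits_ = 8 the quant range is [0, 255]; narrow_range shifts it to [1, 255]. | |||||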
| // init size | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| num_channels_ = SizeToInt(input_shape[0]); | |||||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||||
| quant_num_ *= SizeToInt(input_shape[i]); | |||||
| } | |||||
| input_size_ = sizeof(float); | |||||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||||
| input_size_ *= input_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() { | |||||
| input_size_list_.push_back(input_size_); // input | |||||
| input_size_list_.push_back(sizeof(float) * num_channels_); // min | |||||
| input_size_list_.push_back(sizeof(float) * num_channels_); // max | |||||
| output_size_list_.push_back(sizeof(float) * num_channels_); // output min | |||||
| output_size_list_.push_back(sizeof(float) * num_channels_); // output max | |||||
| } | |||||
| bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||||
| float *output_min = GetDeviceAddress<float>(outputs, 0); | |||||
| float *output_max = GetDeviceAddress<float>(outputs, 1); | |||||
| float *input = GetDeviceAddress<float>(inputs, 0); | |||||
| float *input_min = GetDeviceAddress<float>(inputs, 1); | |||||
| float *input_max = GetDeviceAddress<float>(inputs, 2); | |||||
| if (input == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null."; | |||||
| } | |||||
| if (input_min == nullptr || input_max == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null."; | |||||
| } | |||||
| // calculate the input min and max according to the parameters ema and ema_decay. | |||||
| CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, | |||||
| ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { | |||||
| public: | |||||
| MinMaxUpdatePerChannelGpuKernel(); | |||||
| ~MinMaxUpdatePerChannelGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override; | |||||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||||
| bool Init(const CNodePtr &kernel) override; | |||||
| protected: | |||||
| void InitSizeLists() override; | |||||
| private: | |||||
| size_t input_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| int num_bits_; | |||||
| float quant_min_; | |||||
| float quant_max_; | |||||
| int quant_num_; | |||||
| bool ema_; | |||||
| float ema_decay_; | |||||
| int num_channels_; | |||||
| bool narrow_range_; | |||||
| bool symmetric_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ | |||||
| @@ -0,0 +1,115 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h" | |||||
| #include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" | |||||
| #include <thrust/extrema.h> | |||||
| #include <thrust/pair.h> | |||||
| #include <thrust/device_vector.h> | |||||
| #include <cuda_runtime_api.h> | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() | |||||
| : input_size_(0), | |||||
| num_bits_(0), | |||||
| quant_min_(0), | |||||
| quant_max_(0), | |||||
| quant_num_(1), | |||||
| ema_(false), | |||||
| ema_decay_(0), | |||||
| narrow_range_(false), | |||||
| symmetric_(false) {} | |||||
| const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||||
| const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||||
| const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||||
| bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 3) { | |||||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 2) { | |||||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; | |||||
| } | |||||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||||
| ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||||
| } | |||||
| // quant min and max | |||||
| quant_min_ = 0; | |||||
| quant_max_ = (1 << num_bits_) - 1; | |||||
| if (narrow_range_) { | |||||
| quant_min_++; | |||||
| } | |||||
| // init size | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||||
| quant_num_ *= SizeToInt(input_shape[i]); | |||||
| } | |||||
| input_size_ = sizeof(float); | |||||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||||
| input_size_ *= input_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() { | |||||
| input_size_list_.push_back(input_size_); // input | |||||
| input_size_list_.push_back(sizeof(float)); // input min | |||||
| input_size_list_.push_back(sizeof(float)); // input max | |||||
| output_size_list_.push_back(sizeof(float)); // output min | |||||
| output_size_list_.push_back(sizeof(float)); // output max | |||||
| } | |||||
| bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) { | |||||
| float *output_min = GetDeviceAddress<float>(outputs, 0); | |||||
| float *output_max = GetDeviceAddress<float>(outputs, 1); | |||||
| float *input = GetDeviceAddress<float>(inputs, 0); | |||||
| float *input_min = GetDeviceAddress<float>(inputs, 1); | |||||
| float *input_max = GetDeviceAddress<float>(inputs, 2); | |||||
| if (input == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null."; | |||||
| } | |||||
| if (input_min == nullptr || input_max == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; | |||||
| } | |||||
| CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
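For reference, here is a NumPy sketch of the conventional EMA min/max observer update that these kernels compute on device. The actual formula lives in minmax_update_impl.cu (not shown here), so treat this as an assumed standard form rather than a transcription of the CUDA code.

```python
# NumPy sketch of the usual EMA min/max observer update; this assumes the standard form and is
# not copied from minmax_update_impl.cu.
import numpy as np

def minmax_update_perlayer(x, cur_min, cur_max, ema=False, ema_decay=0.999):
    """Return the updated (min, max) over the whole tensor (per-layer granularity)."""
    batch_min = float(np.min(x))
    batch_max = float(np.max(x))
    if ema:
        # Exponential moving average keeps the running statistics smooth across steps.
        new_min = cur_min * ema_decay + batch_min * (1.0 - ema_decay)
        new_max = cur_max * ema_decay + batch_max * (1.0 - ema_decay)
    else:
        new_min, new_max = batch_min, batch_max
    return new_min, new_max

print(minmax_update_perlayer(np.random.rand(3, 16, 5, 5), -6.0, 6.0, ema=True))
```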
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { | |||||
| public: | |||||
| MinMaxUpdatePerLayerGpuKernel(); | |||||
| ~MinMaxUpdatePerLayerGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override; | |||||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||||
| bool Init(const CNodePtr &kernel) override; | |||||
| protected: | |||||
| void InitSizeLists() override; | |||||
| private: | |||||
| size_t input_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| int num_bits_; | |||||
| float quant_min_; | |||||
| float quant_max_; | |||||
| int quant_num_; | |||||
| bool ema_; | |||||
| float ema_decay_; | |||||
| bool narrow_range_; | |||||
| bool symmetric_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ | |||||
| @@ -28,7 +28,7 @@ message LineageEvent { | |||||
| oneof what { | oneof what { | ||||
| // An event file was started, with the specified version. | // An event file was started, with the specified version. | ||||
| // Now version is "Mindspore.Event:1" | |||||
| // Now version is "MindSpore.Event:1" | |||||
| string version = 3; | string version = 3; | ||||
| // Train lineage | // Train lineage | ||||
| @@ -32,7 +32,7 @@ message Event { | |||||
| oneof what { | oneof what { | ||||
| // An event file was started, with the specified version. | // An event file was started, with the specified version. | ||||
| // Now version is "Mindspore.Event:1" | |||||
| // Now version is "MindSpore.Event:1" | |||||
| string version = 3; | string version = 3; | ||||
| // GraphDef. | // GraphDef. | ||||
| @@ -32,7 +32,7 @@ | |||||
| #include "vm/segment_runner.h" | #include "vm/segment_runner.h" | ||||
| #include "vm/backend.h" | #include "vm/backend.h" | ||||
| // mindspore namespace is the top level namespace of Mindsporeession project. | |||||
| // mindspore namespace is the top level namespace of MindSpore project. | |||||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | // Other namespace should be a sub namespace of mindspore namespace in the ME project. | ||||
| namespace mindspore { | namespace mindspore { | ||||
| extern const char kMsVm[]; | extern const char kMsVm[]; | ||||
| @@ -12,7 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Aware quantization.""" | |||||
| """Quantization aware.""" | |||||
| from functools import partial | from functools import partial | ||||
| import numpy as np | import numpy as np | ||||
| @@ -172,7 +172,7 @@ class DenseBnAct(Cell): | |||||
| Tensor of shape :math:`(N, out\_channels)`. | Tensor of shape :math:`(N, out\_channels)`. | ||||
| Examples: | Examples: | ||||
| >>> net = nn.Dense(3, 4) | |||||
| >>> net = nn.DenseBnAct(3, 4) | |||||
| >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) | ||||
| >>> net(input) | >>> net(input) | ||||
| """ | """ | ||||
| @@ -271,7 +271,7 @@ class BatchNormFoldCell(Cell): | |||||
| class FakeQuantWithMinMax(Cell): | class FakeQuantWithMinMax(Cell): | ||||
| r""" | r""" | ||||
| Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. | |||||
| Quantization aware op. This OP provides a fake quantization observer function on data with min and max. | |||||
| Args: | Args: | ||||
| min_init (int, float): The dimension of channel or 1(layer). Default: -6. | min_init (int, float): The dimension of channel or 1(layer). Default: -6. | ||||
| @@ -338,22 +338,30 @@ class FakeQuantWithMinMax(Cell): | |||||
| # init fake quant relative op | # init fake quant relative op | ||||
| if per_channel: | if per_channel: | ||||
| quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) | quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) | ||||
| ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) | |||||
| ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) | |||||
| else: | else: | ||||
| quant_fun = Q.FakeQuantPerLayer | quant_fun = Q.FakeQuantPerLayer | ||||
| ema_fun = Q.FakeQuantMinMaxPerLayerUpdate | |||||
| ema_fun = Q.MinMaxUpdatePerLayer | |||||
| if self.is_ascend: | if self.is_ascend: | ||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | self.fake_quant = quant_fun(num_bits=self.num_bits, | ||||
| symmetric=self.symmetric, | symmetric=self.symmetric, | ||||
| narrow_range=self.narrow_range) | narrow_range=self.narrow_range) | ||||
| else: | else: | ||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=ema_decay, | |||||
| quant_delay=quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range) | |||||
| self.fake_quant_train = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=ema_decay, | |||||
| quant_delay=quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=True) | |||||
| self.fake_quant_infer = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=ema_decay, | |||||
| quant_delay=quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=False) | |||||
| self.ema_update = ema_fun(num_bits=self.num_bits, | self.ema_update = ema_fun(num_bits=self.num_bits, | ||||
| ema=self.ema, | ema=self.ema, | ||||
| ema_decay=self.ema_decay, | ema_decay=self.ema_decay, | ||||
| @@ -368,16 +376,24 @@ class FakeQuantWithMinMax(Cell): | |||||
| return s | return s | ||||
| def construct(self, x): | def construct(self, x): | ||||
| if self.is_ascend and self.training: | |||||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||||
| out = self.fake_quant(x, min_up, max_up) | |||||
| P.Assign()(self.minq, min_up) | |||||
| P.Assign()(self.maxq, max_up) | |||||
| if self.is_ascend: | |||||
| if self.training: | |||||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||||
| out = self.fake_quant(x, min_up, max_up) | |||||
| P.Assign()(self.minq, min_up) | |||||
| P.Assign()(self.maxq, max_up) | |||||
| else: | |||||
| out = self.fake_quant(x, self.minq, self.maxq) | |||||
| else: | else: | ||||
| out = self.fake_quant(x, self.minq, self.maxq) | |||||
| if self.training: | |||||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||||
| out = self.fake_quant_train(x, min_up, max_up) | |||||
| P.Assign()(self.minq, min_up) | |||||
| P.Assign()(self.maxq, max_up) | |||||
| else: | |||||
| out = self.fake_quant_infer(x, self.minq, self.maxq) | |||||
| return out | return out | ||||
| class Conv2dBatchNormQuant(Cell): | class Conv2dBatchNormQuant(Cell): | ||||
| r""" | r""" | ||||
| 2D convolution with BatchNormal op folded layer. | 2D convolution with BatchNormal op folded layer. | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Generate bprop for aware quantization ops""" | |||||
| """Generate bprop for quantization aware ops""" | |||||
| from .. import operations as P | from .. import operations as P | ||||
| from ..operations import _quant_ops as Q | from ..operations import _quant_ops as Q | ||||
| @@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate) | |||||
| @bprop_getters.register(Q.MinMaxUpdatePerLayer) | |||||
| def get_bprop_fakequant_with_minmax_per_layer_update(self): | def get_bprop_fakequant_with_minmax_per_layer_update(self): | ||||
| """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" | |||||
| """Generate bprop for MinMaxUpdatePerLayer for Ascend""" | |||||
| def bprop(x, x_min, x_max, out, dout): | def bprop(x, x_min, x_max, out, dout): | ||||
| return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | ||||
| @@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate) | |||||
| @bprop_getters.register(Q.MinMaxUpdatePerChannel) | |||||
| def get_bprop_fakequant_with_minmax_per_channel_update(self): | def get_bprop_fakequant_with_minmax_per_channel_update(self): | ||||
| """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" | |||||
| """Generate bprop for MinMaxUpdatePerChannel for Ascend""" | |||||
| def bprop(x, x_min, x_max, out, dout): | def bprop(x, x_min, x_max, out, dout): | ||||
| return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | return zeros_like(x), zeros_like(x_min), zeros_like(x_max) | ||||
| @@ -14,7 +14,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """FakeQuantMinMaxPerChannelUpdate op""" | |||||
| """MinMaxUpdatePerChannel op""" | |||||
| import te.lang.cce | import te.lang.cce | ||||
| from te import tvm | from te import tvm | ||||
| from te.platform.fusion_manager import fusion_manager | from te.platform.fusion_manager import fusion_manager | ||||
| @@ -23,7 +23,7 @@ from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ | |||||
| fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("fake_quant_min_max_per_channel_update.so") \ | .binfile_name("fake_quant_min_max_per_channel_update.so") \ | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """FakeQuantMinMaxPerLayerUpdate op""" | |||||
| """MinMaxUpdatePerLayer op""" | |||||
| from functools import reduce as functools_reduce | from functools import reduce as functools_reduce | ||||
| import te.lang.cce | import te.lang.cce | ||||
| from te import tvm | from te import tvm | ||||
| @@ -23,7 +23,7 @@ from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ | |||||
| fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("fake_quant_minmax_update.so") \ | .binfile_name("fake_quant_minmax_update.so") \ | ||||
| @@ -48,14 +48,14 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ | |||||
| @op_info_register(fake_quant_minmax_update_op_info) | @op_info_register(fake_quant_minmax_update_op_info) | ||||
| def _fake_quant_minmax_update_tbe(): | def _fake_quant_minmax_update_tbe(): | ||||
| """FakeQuantMinMaxPerLayerUpdate TBE register""" | |||||
| """MinMaxUpdatePerLayer TBE register""" | |||||
| return | return | ||||
| @fusion_manager.register("fake_quant_minmax_update") | @fusion_manager.register("fake_quant_minmax_update") | ||||
| def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | ||||
| kernel_name="fake_quant_minmax_update"): | kernel_name="fake_quant_minmax_update"): | ||||
| """FakeQuantMinMaxPerLayerUpdate compute""" | |||||
| """MinMaxUpdatePerLayer compute""" | |||||
| shape = te.lang.cce.util.shape_to_list(x.shape) | shape = te.lang.cce.util.shape_to_list(x.shape) | ||||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | ||||
| min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | ||||
| @@ -25,8 +25,8 @@ __all__ = ["FakeQuantPerLayer", | |||||
| "FakeQuantPerLayerGrad", | "FakeQuantPerLayerGrad", | ||||
| "FakeQuantPerChannel", | "FakeQuantPerChannel", | ||||
| "FakeQuantPerChannelGrad", | "FakeQuantPerChannelGrad", | ||||
| "FakeQuantMinMaxPerLayerUpdate", | |||||
| "FakeQuantMinMaxPerChannelUpdate", | |||||
| "MinMaxUpdatePerLayer", | |||||
| "MinMaxUpdatePerChannel", | |||||
| "BatchNormFold", | "BatchNormFold", | ||||
| "BatchNormFoldGrad", | "BatchNormFoldGrad", | ||||
| "CorrectionMul", | "CorrectionMul", | ||||
| @@ -47,11 +47,11 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| Simulate the quantize and dequantize operations in training time. | Simulate the quantize and dequantize operations in training time. | ||||
| Args: | Args: | ||||
| num_bits (int) : Number bits for aware quantilization. Default: 8. | |||||
| num_bits (int) : Number of bits for quantization aware training. Default: 8. | |||||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | ema (bool): Use EMA algorithm update value min and max. Default: False. | ||||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ||||
| quant_delay (int): Quantization delay parameter. Before delay step in training time not update | quant_delay (int): Quantization delay parameter. Before delay step in training time not update | ||||
| simulate aware quantize funcion. After delay step in training time begin simulate the aware | |||||
| simulate quantization aware function. After delay step in training time begin simulate the aware | |||||
| quantize function. Default: 0. | quantize function. Default: 0. | ||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
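For illustration, a minimal sketch of invoking the primitive documented above directly, modeled on the MinMaxUpdatePerLayer example later in this diff; the import path, shapes, and values are assumptions for the sketch, not part of this change.

```python
# Illustrative sketch only; the import path, shapes and values are assumed, not taken from this PR.
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _quant_ops as Q

x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
min_t = Tensor(np.array([-6.0]), mstype.float32)
max_t = Tensor(np.array([6.0]), mstype.float32)
# Simulated quantize/dequantize of x between the running min and max.
fake_quant = Q.FakeQuantPerLayer(num_bits=8, ema=True, ema_decay=0.999)
out = fake_quant(x, min_t, max_t)
```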
| @@ -834,12 +834,12 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): | |||||
| return dout_type, dout_type | return dout_type, dout_type | ||||
| class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||||
| class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Update min and max value for fake quant per layer op. | Update min and max value for fake quant per layer op. | ||||
| Args: | Args: | ||||
| num_bits (int) : Number bits for aware quantilization. Default: 8. | |||||
| num_bits (int) : Number of bits for quantization aware training. Default: 8. | |||||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | ema (bool): Use EMA algorithm update value min and max. Default: False. | ||||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| @@ -858,14 +858,14 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | ||||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | ||||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | >>> max_tensor = Tensor(np.array([6]), mstype.float32) | ||||
| >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||||
| >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||||
| """ | """ | ||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | ||||
| training=True): | training=True): | ||||
| """init FakeQuantMinMaxPerLayerUpdate OP""" | |||||
| """init MinMaxUpdatePerLayer OP""" | |||||
| if context.get_context('device_target') == "Ascend": | if context.get_context('device_target') == "Ascend": | ||||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update | from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update | ||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| @@ -907,12 +907,12 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||||
| return min_type, max_type | return min_type, max_type | ||||
| class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): | |||||
| class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Update min and max value for fake quant per channel op. | Update min and max value for fake quant per channel op. | ||||
| Args: | Args: | ||||
| num_bits (int) : Number bits for aware quantilization. Default: 8. | |||||
| num_bits (int) : Number of bits for quantization aware training. Default: 8. | |||||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | ema (bool): Use EMA algorithm update value min and max. Default: False. | ||||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| @@ -932,14 +932,14 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): | |||||
| >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | ||||
| >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | ||||
| >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | ||||
| >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max) | |||||
| >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) | |||||
| """ | """ | ||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | ||||
| training=True, channel_axis=1): | training=True, channel_axis=1): | ||||
| """init FakeQuantPerChannelUpdate OP for Ascend""" | |||||
| """init MinMaxUpdatePerChannel OP for Ascend""" | |||||
| if context.get_context('device_target') == "Ascend": | if context.get_context('device_target') == "Ascend": | ||||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update | from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update | ||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| @@ -1932,7 +1932,7 @@ class Eye(PrimitiveWithInfer): | |||||
| Inputs: | Inputs: | ||||
| - **n** (int) - Number of rows of returned tensor | - **n** (int) - Number of rows of returned tensor | ||||
| - **m** (int) - Number of columns of returned tensor | - **m** (int) - Number of columns of returned tensor | ||||
| - **t** (mindspore.dtype) - Mindspore's dtype, The data type of the returned tensor. | |||||
| - **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor. | |||||
| Outputs: | Outputs: | ||||
| Tensor, a tensor with ones on the diagonal and zeros elsewhere. | Tensor, a tensor with ones on the diagonal and zeros elsewhere. | ||||
| @@ -76,7 +76,7 @@ class LossMonitor(Callback): | |||||
| step_loss = np.mean(step_loss.asnumpy()) | step_loss = np.mean(step_loss.asnumpy()) | ||||
| self.losses.append(step_loss) | self.losses.append(step_loss) | ||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num | |||||
| cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) | |||||
| if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | ||||
| raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | ||||
| @@ -88,6 +88,6 @@ class LossMonitor(Callback): | |||||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | ||||
| "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( | "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( | ||||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | cb_params.cur_epoch_num - 1, cb_params.epoch_num, | ||||
| cur_step_in_epoch, cb_params.batch_num, | |||||
| cur_step_in_epoch, int(cb_params.batch_num), | |||||
| step_loss, np.mean(self.losses), | step_loss, np.mean(self.losses), | ||||
| step_mseconds), flush=True) | step_mseconds), flush=True) | ||||
| @@ -15,10 +15,10 @@ | |||||
| """ | """ | ||||
| quantization. | quantization. | ||||
| User can use aware quantization to train a model. Mindspore supports quantization aware training, | |||||
| Users can train a model with quantization awareness. MindSpore supports quantization aware training, | |||||
| which models quantization errors in both the forward and backward passes using fake-quantization | which models quantization errors in both the forward and backward passes using fake-quantization | ||||
| ops. Note that the entire computation is carried out in floating point. At the end of quantization | ops. Note that the entire computation is carried out in floating point. At the end of quantization | ||||
| aware training, Mindspore provides conversion functions to convert the trained model into lower precision. | |||||
| aware training, MindSpore provides conversion functions to convert the trained model into lower precision. | |||||
| """ | """ | ||||
| from .quant import convert_quant_network | from .quant import convert_quant_network | ||||
| @@ -12,12 +12,13 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """aware quantization.""" | |||||
| """quantization aware.""" | |||||
| import copy | import copy | ||||
| import re | import re | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.context as context | |||||
| from ... import log as logger | from ... import log as logger | ||||
| from ... import nn, ops | from ... import nn, ops | ||||
| @@ -234,7 +235,7 @@ class ConvertToQuantNetwork: | |||||
| subcell.has_act = True | subcell.has_act = True | ||||
| subcell.activation = _AddFakeQuantAfterSubCell(F.identity, | subcell.activation = _AddFakeQuantAfterSubCell(F.identity, | ||||
| num_bits=self.act_bits, | num_bits=self.act_bits, | ||||
| quant_delay=self.act_delay, | |||||
| quant_delay=self.act_qdelay, | |||||
| per_channel=self.act_channel, | per_channel=self.act_channel, | ||||
| symmetric=self.act_symmetric, | symmetric=self.act_symmetric, | ||||
| narrow_range=self.act_range) | narrow_range=self.act_range) | ||||
| @@ -403,29 +404,30 @@ def convert_quant_network(network, | |||||
| narrow_range=(False, False) | narrow_range=(False, False) | ||||
| ): | ): | ||||
| r""" | r""" | ||||
| Create aware quantizaiton training network. | |||||
| Create quantization aware training network. | |||||
| Args: | Args: | ||||
| network (Cell): Obtain a pipeline through network for saving graph summary. | network (Cell): Obtain a pipeline through network for saving graph summary. | ||||
| quant_delay (int): Number of steps after which weights and activations are quantized during | |||||
| eval. The first element represent weights and second element represent data flow. Default: [0, 0] | |||||
| quant_delay (int or tuple): Number of steps after which weights and activations are quantized during | |||||
| eval. The first element represent weights and second element represent data flow. Default: (0, 0) | |||||
| bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. | bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. | ||||
| freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0. | freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0. | ||||
| num_bits (list of int): Number of bits to use for quantizing weights and activations. The first | |||||
| element represent weights and second element represent data flow. Default: [8, 8] | |||||
| per_channel (list of bool): Quantization granularity based on layer or on channel. If `True` | |||||
| num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first | |||||
| element represent weights and second element represent data flow. Default: (8, 8) | |||||
| per_channel (bool or tuple): Quantization granularity based on layer or on channel. If `True` | |||||
| then base on per channel otherwise base on per layer. The first element represent weights | then base on per channel otherwise base on per layer. The first element represent weights | ||||
| and second element represent data flow. Default: [False, False] | |||||
| symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on | |||||
| and second element represent data flow. Default: (False, False) | |||||
| symmetric (bool or tuple): Quantization algorithm use symmetric or not. If `True` then base on | |||||
| symmetric otherwise base on asymmetric. The first element represent weights and second | symmetric otherwise base on asymmetric. The first element represent weights and second | ||||
| element represent data flow. Default: [False, False] | |||||
| narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base | |||||
| element represent data flow. Default: (False, False) | |||||
| narrow_range (bool or tuple): Quantization algorithm use narrow range or not. If `True` then base | |||||
| on narrow range otherwise base on off narrow range. The first element represent weights and | on narrow range otherwise base on off narrow range. The first element represent weights and | ||||
| second element represent data flow. Default: [False, False] | |||||
| second element represent data flow. Default: (False, False) | |||||
| Returns: | Returns: | ||||
| Cell, Network which has change to aware quantization training network cell. | |||||
| Cell, network that has been converted into a quantization aware training network. | |||||
| """ | """ | ||||
| support_device = ["Ascend", "GPU"] | |||||
| def convert2list(name, value): | def convert2list(name, value): | ||||
| if not isinstance(value, list) and not isinstance(value, tuple): | if not isinstance(value, list) and not isinstance(value, tuple): | ||||
| value = [value] | value = [value] | ||||
| @@ -439,6 +441,9 @@ def convert_quant_network(network, | |||||
| symmetric = convert2list("symmetric", symmetric) | symmetric = convert2list("symmetric", symmetric) | ||||
| narrow_range = convert2list("narrow range", narrow_range) | narrow_range = convert2list("narrow range", narrow_range) | ||||
| if context.get_context('device_target') not in support_device: | |||||
| raise KeyError("Not support {} backend.".format(context.get_context('device_target'))) | |||||
| net = ConvertToQuantNetwork(network=network, | net = ConvertToQuantNetwork(network=network, | ||||
| quant_delay=quant_delay, | quant_delay=quant_delay, | ||||
| bn_fold=bn_fold, | bn_fold=bn_fold, | ||||
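As a usage illustration of the tuple-style defaults documented above, a minimal sketch follows; the `mindspore.train.quant` import path and the `LeNet5Fusion` network are taken from the README examples further below and are assumptions rather than part of this diff.

```python
# Sketch only; the quant module path and LeNet5Fusion are assumed from the README examples below.
from mindspore.train.quant import quant
from lenet_fusion import LeNet5Fusion  # hypothetical module that defines the fusion network

network = LeNet5Fusion(10)  # fusion network built from nn.Conv2dBnAct / nn.DenseBnAct
# The first element of each tuple configures weights, the second configures activations (data flow).
network = quant.convert_quant_network(network,
                                      quant_delay=(0, 0),
                                      bn_fold=False,
                                      freeze_bn=10000,
                                      num_bits=(8, 8),
                                      per_channel=(False, False),
                                      symmetric=(False, False),
                                      narrow_range=(False, False))
```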
| @@ -30,7 +30,7 @@ MS_IMAGE_TENSOR_FORMAT = 'NCHW' | |||||
| # Set the Event mark | # Set the Event mark | ||||
| EVENT_FILE_NAME_MARK = ".out.events.summary." | EVENT_FILE_NAME_MARK = ".out.events.summary." | ||||
| # Set the init event of version and mark | # Set the init event of version and mark | ||||
| EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:" | |||||
| EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:" | |||||
| EVENT_FILE_INIT_VERSION = 1 | EVENT_FILE_INIT_VERSION = 1 | ||||
| F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max | F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max | ||||
| @@ -2,13 +2,13 @@ | |||||
| ## Description | ## Description | ||||
| Training LeNet with MNIST dataset in MindSpore with quantization aware trainging. | |||||
| Training LeNet with MNIST dataset in MindSpore with quantization aware training. | |||||
| This is a simple and basic tutorial for constructing a network in MindSpore with quantization aware training. | This is a simple and basic tutorial for constructing a network in MindSpore with quantization aware training. | ||||
| In this tutorial, you will: | In this tutorial, you will: | ||||
| 1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. | |||||
| 1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. | |||||
| 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`; after the network converges, export a quantization aware model checkpoint file. | 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`; after the network converges, export a quantization aware model checkpoint file. | ||||
| 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend. | 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend. | ||||
| 4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples. | 4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples. | ||||
| @@ -24,10 +24,10 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http | |||||
| ```python | ```python | ||||
| pip uninstall -y mindspore-ascend | pip uninstall -y mindspore-ascend | ||||
| pip uninstall -y mindspore-gpu | pip uninstall -y mindspore-gpu | ||||
| pip install mindspore-ascend-0.4.0.whl | |||||
| pip install mindspore-ascend.whl | |||||
| ``` | ``` | ||||
| then you will get the following display | |||||
| Then you will get the following display | |||||
| ```bash | ```bash | ||||
| @@ -87,7 +87,7 @@ class LeNet5(nn.Cell): | |||||
| return x | return x | ||||
| ``` | ``` | ||||
| get the MNIST from scratch dataset. | |||||
| Get the MNIST dataset for training from scratch. | |||||
| ```Python | ```Python | ||||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), | ds_train = create_dataset(os.path.join(args.data_path, "train"), | ||||
| @@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size() | |||||
| ### Train model | ### Train model | ||||
| Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. | |||||
| Load the LeNet fusion network, and train it using the loss `nn.SoftmaxCrossEntropyWithLogits` with the `nn.Momentum` optimizer. | |||||
| ```Python | ```Python | ||||
| # Define the network | # Define the network | ||||
| @@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following: | |||||
| >>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | >>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | ||||
| ``` | ``` | ||||
| To save your time, just run this command. | |||||
| Also, you can just run this command instead. | |||||
| ```python | ```python | ||||
| python train.py --data_path MNIST_Data --device_target Ascend | python train.py --data_path MNIST_Data --device_target Ascend | ||||
| @@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # load aware quantizaiton network checkpoint | |||||
| # load quantization aware network checkpoint | |||||
| param_dict = load_checkpoint(args.ckpt_path) | param_dict = load_checkpoint(args.ckpt_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| # convert funsion netwrok to aware quantizaiton network | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network) | network = quant.convert_quant_network(network) | ||||
| ``` | ``` | ||||
| ### load checkpoint | ### load checkpoint | ||||
| after convert to quantization aware network, we can load the checkpoint file. | |||||
| After converting to the quantization aware network, we can load the checkpoint file. | |||||
| ```python | ```python | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | ||||
| @@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||||
| ### train quantization aware model | ### train quantization aware model | ||||
| To save your time, just run this command. | |||||
| Also, you can just run this command instead. | |||||
| ```python | ```python | ||||
| python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt | python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt | ||||
| @@ -210,18 +210,18 @@ Procedure of quantization aware model evaluation is different from normal. Becau | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # load aware quantizaiton network checkpoint | |||||
| # load quantization aware network checkpoint | |||||
| param_dict = load_checkpoint(args.ckpt_path) | param_dict = load_checkpoint(args.ckpt_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| # convert funsion netwrok to aware quantizaiton network | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network) | network = quant.convert_quant_network(network) | ||||
| ``` | ``` | ||||
| To save your time, just run this command. | |||||
| Also, you can just run this command instead. | |||||
| ```python | ```python | ||||
| python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt | |||||
| python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt | |||||
| ``` | ``` | ||||
| The top1 accuracy would display on shell. | The top1 accuracy would display on shell. | ||||
| @@ -50,7 +50,7 @@ if __name__ == "__main__": | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # convert funsion netwrok to aware quantizaiton network | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | ||||
| @@ -60,7 +60,7 @@ if __name__ == "__main__": | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | ||||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | ||||
| # load aware quantizaiton network checkpoint | |||||
| # load quantization aware network checkpoint | |||||
| param_dict = load_checkpoint(args.ckpt_path) | param_dict = load_checkpoint(args.ckpt_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| @@ -50,10 +50,10 @@ if __name__ == "__main__": | |||||
| # define fusion network | # define fusion network | ||||
| network = LeNet5Fusion(cfg.num_classes) | network = LeNet5Fusion(cfg.num_classes) | ||||
| # load aware quantizaiton network checkpoint | |||||
| # load quantization aware network checkpoint | |||||
| param_dict = load_checkpoint(args.ckpt_path) | param_dict = load_checkpoint(args.ckpt_path) | ||||
| load_param_into_net(network, param_dict) | load_param_into_net(network, param_dict) | ||||
| # convert funsion netwrok to aware quantizaiton network | |||||
| # convert fusion network to quantization aware network | |||||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | ||||
| @@ -18,14 +18,14 @@ | |||||
| import sys | import sys | ||||
| class _MindsporeTestFrameworkkeyword: | |||||
| class _MindSporeTestFrameworkkeyword: | |||||
| def __setattr__(self, name, value): | def __setattr__(self, name, value): | ||||
| if name in self.__dict__: | if name in self.__dict__: | ||||
| raise TypeError("can not rebind keyword (%s)" % name) | raise TypeError("can not rebind keyword (%s)" % name) | ||||
| self.__dict__[name] = value | self.__dict__[name] = value | ||||
| keyword = _MindsporeTestFrameworkkeyword() | |||||
| keyword = _MindSporeTestFrameworkkeyword() | |||||
| keyword.function = "function" | keyword.function = "function" | ||||
| keyword.inputs = "inputs" | keyword.inputs = "inputs" | ||||