Merge pull request !28546 from kisnwang/add-cpu-adafactorfeature/build-system-rewrite
| @@ -24,8 +24,8 @@ namespace { | |||
| static constexpr size_t kSizeFloat32 = sizeof(float); | |||
| static constexpr size_t kSizeFloat16 = sizeof(float16); | |||
| static constexpr size_t kScalarIndex = 0; | |||
| static constexpr size_t kFusedAdaFactorInputNum = 12; | |||
| static constexpr size_t kFusedAdaFactorWorkSpaceNum = 3; | |||
| static constexpr size_t kStandardInputNum = 12; | |||
| static constexpr size_t kWorkSpaceNum = 3; | |||
| static constexpr size_t kBatchSize = 10000; | |||
| static auto constexpr kEnableScaleParameter = "enable_scale_parameter"; | |||
| static auto constexpr kEnableFirstMoment = "enable_first_moment"; | |||
| @@ -37,15 +37,9 @@ static constexpr float kEps = 1e-30; | |||
| void FusedAdaFactorCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| if (param_dtype_ == kNumberTypeFloat16) { | |||
| (void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat16); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat16); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat16); | |||
| } else { | |||
| (void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat32); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat32); | |||
| } | |||
| (void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat32); | |||
| (void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat32); | |||
| } | |||
| void FusedAdaFactorCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| @@ -93,14 +87,14 @@ float FusedAdaFactorCPUKernel::CalcRMS(T *input, size_t elem_num) { | |||
| } | |||
| template <typename T> | |||
| void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressPtr> &inputs, | |||
| void FusedAdaFactorCPUKernel::FactorUpdate(float *update, const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspaces) { | |||
| auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex]; | |||
| auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr); | |||
| auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_ROW]->addr); | |||
| auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_COL]->addr); | |||
| auto r_factor = reinterpret_cast<T *>(workspaces[R_FACTOR]->addr); | |||
| auto c_factor = reinterpret_cast<T *>(workspaces[C_FACTOR]->addr); | |||
| auto r_factor = reinterpret_cast<float *>(workspaces[R_FACTOR]->addr); | |||
| auto c_factor = reinterpret_cast<float *>(workspaces[C_FACTOR]->addr); | |||
| auto one_minus_beta2t = 1 - beta2t; | |||
| std::function<void(size_t, size_t)> task; | |||
| @@ -115,7 +109,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP | |||
| float row_reduce = 0; | |||
| size_t reduce_start = i * row_dim_size; | |||
| for (size_t j = 0; j < row_dim_size; ++j) { | |||
| row_reduce += static_cast<float>(update[reduce_start + j]); | |||
| row_reduce += update[reduce_start + j]; | |||
| } | |||
| row_reduce = row_reduce / row_dim_size; | |||
| auto tmp = static_cast<float>(exp_avg_sq_row[i]) * beta2t + row_reduce * one_minus_beta2t; | |||
| @@ -135,8 +129,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP | |||
| col_reduce /= col_dim_size; | |||
| col_reduce = std::max(col_reduce, kEps); | |||
| for (size_t j = 0; j < col_dim_size; ++j) { | |||
| auto tmp = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce); | |||
| r_factor[reduce_start + j] = static_cast<T>(tmp); | |||
| r_factor[reduce_start + j] = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce); | |||
| } | |||
| } | |||
| }; | |||
| @@ -149,13 +142,13 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP | |||
| float row_reduce = 0; | |||
| size_t reduce_start = i / row_dim_size * last_row_col_size + i % row_dim_size; | |||
| for (size_t j = 0; j < col_dim_size; ++j) { | |||
| row_reduce += static_cast<float>(update[reduce_start + j * row_dim_size]); | |||
| row_reduce += update[reduce_start + j * row_dim_size]; | |||
| } | |||
| row_reduce = row_reduce / col_dim_size; | |||
| auto tmp = static_cast<float>(exp_avg_sq_col[i]) * beta2t + row_reduce * one_minus_beta2t; | |||
| tmp = std::max(tmp, kEps); | |||
| exp_avg_sq_col[i] = static_cast<T>(tmp); | |||
| c_factor[i] = static_cast<T>(std::sqrt(tmp)); | |||
| c_factor[i] = std::sqrt(tmp); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize); | |||
| @@ -166,12 +159,8 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP | |||
| size_t row_i = i % row_dim_size; | |||
| size_t col_i = i / row_dim_size % col_dim_size; | |||
| size_t slice = i / last_row_col_size; | |||
| auto left = static_cast<float>(r_factor[slice * col_dim_size + col_i]); | |||
| auto right = static_cast<float>(c_factor[slice * row_dim_size + row_i]); | |||
| auto norm = left * right; | |||
| norm = std::max(norm, kEps); | |||
| auto tmp = static_cast<float>(grad[i]) / norm; | |||
| update[i] = static_cast<T>(tmp); | |||
| auto norm = r_factor[slice * col_dim_size + col_i] * c_factor[slice * row_dim_size + row_i]; | |||
| update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::max(norm, kEps); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize); | |||
| @@ -190,7 +179,7 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| auto param = reinterpret_cast<T *>(inputs[PARAM]->addr); | |||
| auto exp_avg = reinterpret_cast<T *>(inputs[EXP_AVG]->addr); | |||
| auto exp_avg_sq = reinterpret_cast<T *>(inputs[EXP_AVG_SQ]->addr); | |||
| auto update = reinterpret_cast<T *>(workspaces[UPDATE]->addr); | |||
| auto update = reinterpret_cast<float *>(workspaces[UPDATE]->addr); | |||
| auto one_minus_beta1 = 1 - beta1; | |||
| auto one_minus_beta2t = 1 - beta2t; | |||
| if (clip_threshold <= 0) { | |||
| @@ -211,23 +200,22 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| // update = grad * grad + eps[0] | |||
| task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; ++i) { | |||
| auto tmp = static_cast<float>(grad[i]); | |||
| update[i] = static_cast<T>(tmp * tmp + epsilon[0]); | |||
| auto tmp = static_cast<float>(grad[i]) * global_norm_reciprocal_; | |||
| update[i] = tmp * tmp + epsilon[0]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize); | |||
| if (need_factor_) { | |||
| FactorUpdate(update, inputs, workspaces); | |||
| FactorUpdate<T>(update, inputs, workspaces); | |||
| } else { | |||
| // no factor | |||
| task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; ++i) { | |||
| auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + static_cast<float>(update[i]) * one_minus_beta2t; | |||
| auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + update[i] * one_minus_beta2t; | |||
| tmp = std::max(tmp, kEps); | |||
| exp_avg_sq[i] = static_cast<T>(tmp); | |||
| tmp = static_cast<float>(grad[i]) / std::sqrt(static_cast<float>(exp_avg_sq[i])); | |||
| update[i] = static_cast<T>(tmp); | |||
| update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::sqrt(tmp); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize); | |||
| @@ -244,18 +232,16 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| auto update_coff = learning_rate / std::max(update_rms_thres, 1.0f); | |||
| task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; ++i) { | |||
| auto tmp = static_cast<float>(update[i]) * update_coff; | |||
| update[i] = static_cast<T>(tmp); | |||
| update[i] = update[i] * update_coff; | |||
| if (enable_first_moment_) { | |||
| tmp = static_cast<float>(exp_avg[i]) * beta1 + static_cast<float>(update[i]) * one_minus_beta1; | |||
| exp_avg[i] = static_cast<T>(tmp); | |||
| update[i] = exp_avg[i]; | |||
| update[i] = static_cast<float>(exp_avg[i]) * beta1 + update[i] * one_minus_beta1; | |||
| exp_avg[i] = static_cast<T>(update[i]); | |||
| } | |||
| if (enable_weight_decay_) { | |||
| tmp = static_cast<float>(param[i]) * weight_decay * learning_rate; | |||
| param[i] = param[i] - update[i] - static_cast<T>(tmp); | |||
| auto tmp = static_cast<float>(param[i]) * weight_decay * learning_rate; | |||
| param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i] - tmp); | |||
| } else { | |||
| param[i] = param[i] - update[i]; | |||
| param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i]); | |||
| } | |||
| } | |||
| }; | |||
| @@ -265,15 +251,17 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| bool FusedAdaFactorCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspaces, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() != kFusedAdaFactorInputNum) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kFusedAdaFactorInputNum | |||
| << ", but got: " << inputs.size(); | |||
| } | |||
| if (workspaces.size() != kFusedAdaFactorWorkSpaceNum) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces should be " | |||
| << kFusedAdaFactorWorkSpaceNum << ", but got: " << workspaces.size(); | |||
| if (inputs.size() == kStandardInputNum + 1) { | |||
| auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex]; | |||
| if (global_norm < kEps) { | |||
| global_norm_reciprocal_ = 1.0f; | |||
| } else { | |||
| global_norm_reciprocal_ = 1.0f / global_norm; | |||
| } | |||
| } | |||
| CheckParam(inputs, workspaces, outputs); | |||
| CheckInputAddresses(inputs); | |||
| CheckWorkspaceAddresses(workspaces); | |||
| if (param_dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, workspaces, outputs); | |||
| } else { | |||
| @@ -282,9 +270,12 @@ bool FusedAdaFactorCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu | |||
| return true; | |||
| } | |||
| void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspaces, | |||
| const std::vector<kernel::AddressPtr> &) const { | |||
| void FusedAdaFactorCPUKernel::CheckInputAddresses(const std::vector<kernel::AddressPtr> &inputs) const { | |||
| if (inputs.size() < kStandardInputNum) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be at least " << kStandardInputNum | |||
| << ", but got: " << inputs.size(); | |||
| } | |||
| if (inputs[EPSILON]->size != kSizeFloat32 << 1) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' should be " << (kSizeFloat32 << 1) | |||
| << ", but got " << inputs[EPSILON]->size; | |||
| @@ -293,7 +284,6 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> & | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32 | |||
| << ", but got " << inputs[BETA1]->size; | |||
| } | |||
| if (inputs[BETA1]->size != kSizeFloat32) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32 | |||
| << ", but got " << inputs[BETA1]->size; | |||
| @@ -320,10 +310,6 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> & | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' should be " << param_size | |||
| << ", but got " << inputs[GRAD]->size; | |||
| } | |||
| if (workspaces[UPDATE]->size != param_size) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update ' should be " << param_size | |||
| << ", but got " << workspaces[0]->size; | |||
| } | |||
| if (enable_first_moment_ && inputs[EXP_AVG]->size != param_size) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' should be " << param_size | |||
| @@ -347,5 +333,28 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> & | |||
| << param_size / last_col_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_COL]->size; | |||
| } | |||
| } | |||
| void FusedAdaFactorCPUKernel::CheckWorkspaceAddresses(const std::vector<kernel::AddressPtr> &workspaces) const { | |||
| if (workspaces.size() != kWorkSpaceNum) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces should be " << kWorkSpaceNum | |||
| << ", but got: " << workspaces.size(); | |||
| } | |||
| size_t update_size = elem_num_ * kSizeFloat32; | |||
| if (workspaces[UPDATE]->size != elem_num_ * kSizeFloat32) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update ' should be " << update_size | |||
| << ", but got " << workspaces[0]->size; | |||
| } | |||
| if (workspaces[R_FACTOR]->size != update_size / last_row_dim_size_) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'r_factor' should be " | |||
| << update_size / last_row_dim_size_ << ", but got " << workspaces[R_FACTOR]->size; | |||
| } | |||
| if (workspaces[C_FACTOR]->size != update_size / last_col_dim_size_) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'c_factor' should be " | |||
| << update_size / last_col_dim_size_ << ", but got " << workspaces[C_FACTOR]->size; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -33,8 +33,9 @@ class FusedAdaFactorCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||
| const std::vector<AddressPtr> &outputs) const; | |||
| void CheckInputAddresses(const std::vector<AddressPtr> &inputs) const; | |||
| void CheckWorkspaceAddresses(const std::vector<AddressPtr> &workspaces) const; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||
| const std::vector<AddressPtr> &outputs); | |||
| @@ -43,7 +44,7 @@ class FusedAdaFactorCPUKernel : public CPUKernel { | |||
| float CalcRMS(T *input, size_t elem_num); | |||
| template <typename T> | |||
| void FactorUpdate(T *update, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces); | |||
| void FactorUpdate(float *update, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces); | |||
| bool enable_scale_parameter_{false}; | |||
| bool enable_first_moment_{false}; | |||
| @@ -53,6 +54,7 @@ class FusedAdaFactorCPUKernel : public CPUKernel { | |||
| size_t last_row_dim_size_{1}; | |||
| size_t last_col_dim_size_{1}; | |||
| TypeId param_dtype_{kTypeUnknown}; | |||
| float global_norm_reciprocal_{1.0f}; | |||
| enum InputEnum { | |||
| EPSILON, | |||
| @@ -66,7 +68,8 @@ class FusedAdaFactorCPUKernel : public CPUKernel { | |||
| EXP_AVG, | |||
| EXP_AVG_SQ_ROW, | |||
| EXP_AVG_SQ_COL, | |||
| EXP_AVG_SQ | |||
| EXP_AVG_SQ, | |||
| GLOBAL_NORM | |||
| }; | |||
| enum WorkspaceEnum { UPDATE, R_FACTOR, C_FACTOR }; | |||
| @@ -105,6 +108,42 @@ MS_REG_CPU_KERNEL(FusedAdaFactor, | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| FusedAdaFactorCPUKernel) | |||
| MS_REG_CPU_KERNEL(FusedAdaFactorWithGlobalNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedAdaFactorCPUKernel) | |||
| MS_REG_CPU_KERNEL(FusedAdaFactorWithGlobalNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| FusedAdaFactorCPUKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -215,6 +215,7 @@ constexpr auto kAdamWeightDecayName = "AdamWeightDecay"; | |||
| constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay"; | |||
| constexpr auto kFusedAdamName = "FusedAdam"; | |||
| constexpr auto kFusedAdaFactorName = "FusedAdaFactor"; | |||
| constexpr auto kFusedAdaFactorWithGlobalNormName = "FusedAdaFactorWithGlobalNorm"; | |||
| constexpr auto kFusedSparseAdamName = "FusedSparseAdam"; | |||
| constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd"; | |||
| constexpr auto kDeadNodeName = "DeadNode"; | |||
| @@ -693,6 +694,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||
| kFusedCastAdamWeightDecayName, | |||
| kFusedAdamName, | |||
| kFusedAdaFactorName, | |||
| kFusedAdaFactorWithGlobalNormName, | |||
| kFusedSparseAdamName, | |||
| kFusedMulApplyMomentumOpName, | |||
| kFusedWeightScaleApplyMomentum, | |||
| @@ -35,7 +35,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta | |||
| BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, | |||
| EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted, | |||
| TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches) | |||
| from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast, | |||
| from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, | |||
| Broadcast, | |||
| _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, | |||
| _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, | |||
| _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather) | |||
| @@ -44,7 +45,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm | |||
| from .control_ops import GeSwitch, Merge | |||
| from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, | |||
| MakeRefKey, | |||
| FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor) | |||
| FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor, | |||
| FusedAdaFactorWithGlobalNorm) | |||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | |||
| BitwiseAnd, BitwiseOr, Ger, | |||
| @@ -178,6 +180,7 @@ __all__ = [ | |||
| 'FusedSparseLazyAdam', | |||
| 'AdamNoUpdateParam', | |||
| 'FusedAdaFactor', | |||
| 'FusedAdaFactorWithGlobalNorm', | |||
| 'Softplus', | |||
| 'Softmax', | |||
| 'Softsign', | |||
| @@ -689,3 +689,25 @@ class FusedAdaFactor(PrimitiveWithInfer): | |||
| learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type, | |||
| exp_avg_sq_col_type, exp_avg_sq_type): | |||
| return param_type | |||
| class FusedAdaFactorWithGlobalNorm(FusedAdaFactor): | |||
| r""" | |||
| Divide global norm for gradient in FusedAdaFactor, and refer to super class for FusedAdaFactor details | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False): | |||
| super(FusedAdaFactorWithGlobalNorm, self).__init__(enable_scale_parameter, enable_first_moment, | |||
| enable_weight_decay) | |||
| def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape, | |||
| learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape, | |||
| exp_avg_sq_col_shape, exp_avg_sq_shape, global_norm_shape): | |||
| validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ, self.name) | |||
| return param_shape | |||
| def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type, | |||
| learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type, | |||
| exp_avg_sq_col_type, exp_avg_sq_type, global_norm_type): | |||
| return param_type | |||
| @@ -39,6 +39,17 @@ class Net(nn.Cell): | |||
| return out | |||
| class NetWithGlobalNorm(Net): | |||
| def __init__(self): | |||
| super(NetWithGlobalNorm, self).__init__() | |||
| self.opt = ops.FusedAdaFactorWithGlobalNorm() | |||
| def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, global_norm): | |||
| out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param, self.exp_avg, | |||
| self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq, global_norm) | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @@ -54,3 +65,20 @@ def test_adafactor(): | |||
| net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient) | |||
| diff = net.param.asnumpy() - np.ones(param_shape) * 0.97 | |||
| assert np.all(diff < 1e-3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_adafactor_with_global_norm(): | |||
| ''' | |||
| Feature: AdaFactor | |||
| Description: Test AdaFactor | |||
| Expectation: Run success | |||
| ''' | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| net = NetWithGlobalNorm() | |||
| gradient = Tensor(np.ones(param_shape), mstype.float32) | |||
| net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient, 10.0) | |||
| diff = net.param.asnumpy() - np.ones(param_shape) * 0.97 | |||
| assert np.all(diff < 1e-3) | |||
| @@ -57,7 +57,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| grad_.resize(elem_num_); | |||
| exp_avg_.resize(elem_num_); | |||
| exp_avg_sq_.resize(elem_num_); | |||
| update_.resize(elem_num_); | |||
| update_.resize(elem_num_, 0.0f); | |||
| for (size_t i = 0; i < elem_num_; ++i) { | |||
| auto ptr = (float16 *)param_.data(); | |||
| ptr[i] = static_cast<float16>(1.0f); | |||
| @@ -67,8 +67,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| ptr = (float16 *)exp_avg_sq_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| ptr = (float16 *)update_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| } | |||
| auto r_factor_num = elem_num_ / last_row_dim_size_; | |||
| @@ -77,8 +75,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| for (size_t i = 0; i < r_factor_num; ++i) { | |||
| auto ptr = (float16 *)exp_avg_sq_row_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| ptr = (float16 *)r_factor_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| } | |||
| auto c_factor_num = elem_num_ / last_col_dim_size_; | |||
| @@ -87,8 +83,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| for (size_t i = 0; i < c_factor_num; ++i) { | |||
| auto ptr = (float16 *)exp_avg_sq_col_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| ptr = (float16 *)c_factor_.data(); | |||
| ptr[i] = static_cast<float16>(0.0f); | |||
| } | |||
| } | |||
| @@ -99,7 +93,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| return kernel_addr; | |||
| } | |||
| void CreateAddress() { | |||
| void CreateAddress(bool enable_global_norm) { | |||
| constexpr size_t eps_num = 2; | |||
| inputs_.push_back(CreateKernelAddress(epsilon_.data(), eps_num, kSizeFloat32)); | |||
| inputs_.push_back(CreateKernelAddress(&clip_threshold_, 1, kSizeFloat32)); | |||
| @@ -113,36 +107,37 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| inputs_.push_back(CreateKernelAddress(exp_avg_sq_row_.data(), elem_num_ / last_row_dim_size_, type_size_)); | |||
| inputs_.push_back(CreateKernelAddress(exp_avg_sq_col_.data(), elem_num_ / last_col_dim_size_, type_size_)); | |||
| inputs_.push_back(CreateKernelAddress(exp_avg_sq_.data(), elem_num_, type_size_)); | |||
| workspace_.push_back(CreateKernelAddress(update_.data(), elem_num_, type_size_)); | |||
| workspace_.push_back(CreateKernelAddress(r_factor_.data(), elem_num_ / last_row_dim_size_, type_size_)); | |||
| workspace_.push_back(CreateKernelAddress(c_factor_.data(), elem_num_ / last_col_dim_size_, type_size_)); | |||
| workspace_.push_back(CreateKernelAddress(update_.data(), elem_num_, kSizeFloat32)); | |||
| workspace_.push_back(CreateKernelAddress(r_factor_.data(), elem_num_ / last_row_dim_size_, kSizeFloat32)); | |||
| workspace_.push_back(CreateKernelAddress(c_factor_.data(), elem_num_ / last_col_dim_size_, kSizeFloat32)); | |||
| if (enable_global_norm) { | |||
| inputs_.push_back(CreateKernelAddress(&global_norm_, 1, kSizeFloat32)); | |||
| } | |||
| } | |||
| void ComputeFp32() { | |||
| void ComputeFp32(bool enable_global_norm) { | |||
| ada_factor_->param_dtype_ = kNumberTypeFloat32; | |||
| type_size_ = sizeof(float); | |||
| InitDataFp32(); | |||
| CreateAddress(); | |||
| CreateAddress(enable_global_norm); | |||
| ada_factor_->Launch(inputs_, workspace_, outputs_); | |||
| constexpr float result = 0.97; | |||
| for (size_t i = 0; i < elem_num_; ++i) { | |||
| EXPECT_TRUE(std::fabs(param_[i] - result) < 1e-6); | |||
| EXPECT_TRUE(std::fabs(param_[i] - result_) < 1e-6); | |||
| } | |||
| } | |||
| void ComputeFp16() { | |||
| void ComputeFp16(bool enable_global_norm) { | |||
| ada_factor_->param_dtype_ = kNumberTypeFloat16; | |||
| type_size_ = sizeof(float16); | |||
| InitDataFp16(); | |||
| CreateAddress(); | |||
| CreateAddress(enable_global_norm); | |||
| ada_factor_->Launch(inputs_, workspace_, outputs_); | |||
| constexpr float result = 0.97; | |||
| auto ptr = (float16 *)param_.data(); | |||
| for (size_t i = 0; i < elem_num_; ++i) { | |||
| EXPECT_TRUE(std::fabs(static_cast<float>(ptr[i]) - result) < 1e-3); | |||
| EXPECT_TRUE(std::fabs(static_cast<float>(ptr[i]) - result_) < 1e-3); | |||
| } | |||
| } | |||
| @@ -152,6 +147,8 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| float beta1_ = 0.9; | |||
| float beta2t_ = 0.8; | |||
| float weight_decay_ = 1e-2; | |||
| float global_norm_ = 10.0f; | |||
| float result_ = 0.97; | |||
| std::vector<float> param_; | |||
| std::vector<float> grad_; | |||
| std::vector<float> exp_avg_; | |||
| @@ -179,7 +176,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common { | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor) { | |||
| ada_factor_->need_factor_ = true; | |||
| ComputeFp32(); | |||
| ComputeFp32(false); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| @@ -187,7 +184,23 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor) { | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor) { | |||
| ada_factor_->need_factor_ = false; | |||
| ComputeFp32(); | |||
| ComputeFp32(false); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| /// Description: Run FusedAdaFactor that needs factor state with fp32 data inputs and global norm | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor_global_norm) { | |||
| ada_factor_->need_factor_ = true; | |||
| ComputeFp32(true); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| /// Description: Run FusedAdaFactor that doesn't need factor state with fp32 data inputs and global norm | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor_global_norm) { | |||
| ada_factor_->need_factor_ = false; | |||
| ComputeFp32(true); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| @@ -195,7 +208,7 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor) { | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor) { | |||
| ada_factor_->need_factor_ = true; | |||
| ComputeFp16(); | |||
| ComputeFp16(false); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| @@ -203,7 +216,23 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor) { | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_no_factor) { | |||
| ada_factor_->need_factor_ = false; | |||
| ComputeFp16(); | |||
| ComputeFp16(false); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| /// Description: Run FusedAdaFactor that needs factor state with fp16 data inputs and global norm | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor_global_norm) { | |||
| ada_factor_->need_factor_ = true; | |||
| ComputeFp16(true); | |||
| } | |||
| /// Feature: FusedAdaFactor | |||
| /// Description: Run FusedAdaFactor that doesn't need factor state with fp16 data inputs and global norm | |||
| /// Expectation: pass | |||
| TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_no_factor_global_norm) { | |||
| ada_factor_->need_factor_ = false; | |||
| ComputeFp16(true); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||