Browse Source

!28546 fused global norm for cpu adafactor

Merge pull request !28546 from kisnwang/add-cpu-adafactor
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
7384cb2689
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 217 additions and 85 deletions
  1. +65
    -56
      mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.cc
  2. +43
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h
  3. +2
    -0
      mindspore/ccsrc/utils/utils.h
  4. +5
    -2
      mindspore/python/mindspore/ops/operations/__init__.py
  5. +22
    -0
      mindspore/python/mindspore/ops/operations/inner_ops.py
  6. +28
    -0
      tests/st/ops/cpu/test_fused_ada_factor_op.py
  7. +52
    -23
      tests/ut/cpp/kernel/cpu/fused_ada_factor_cpu_kernel_test.cc

+ 65
- 56
mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.cc View File

@@ -24,8 +24,8 @@ namespace {
static constexpr size_t kSizeFloat32 = sizeof(float);
static constexpr size_t kSizeFloat16 = sizeof(float16);
static constexpr size_t kScalarIndex = 0;
static constexpr size_t kFusedAdaFactorInputNum = 12;
static constexpr size_t kFusedAdaFactorWorkSpaceNum = 3;
static constexpr size_t kStandardInputNum = 12;
static constexpr size_t kWorkSpaceNum = 3;
static constexpr size_t kBatchSize = 10000;
static auto constexpr kEnableScaleParameter = "enable_scale_parameter";
static auto constexpr kEnableFirstMoment = "enable_first_moment";
@@ -37,15 +37,9 @@ static constexpr float kEps = 1e-30;

void FusedAdaFactorCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
if (param_dtype_ == kNumberTypeFloat16) {
(void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat16);
(void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat16);
(void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat16);
} else {
(void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat32);
}
(void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat32);
}

void FusedAdaFactorCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -93,14 +87,14 @@ float FusedAdaFactorCPUKernel::CalcRMS(T *input, size_t elem_num) {
}

template <typename T>
void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressPtr> &inputs,
void FusedAdaFactorCPUKernel::FactorUpdate(float *update, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspaces) {
auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex];
auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr);
auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_ROW]->addr);
auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_COL]->addr);
auto r_factor = reinterpret_cast<T *>(workspaces[R_FACTOR]->addr);
auto c_factor = reinterpret_cast<T *>(workspaces[C_FACTOR]->addr);
auto r_factor = reinterpret_cast<float *>(workspaces[R_FACTOR]->addr);
auto c_factor = reinterpret_cast<float *>(workspaces[C_FACTOR]->addr);
auto one_minus_beta2t = 1 - beta2t;

std::function<void(size_t, size_t)> task;
@@ -115,7 +109,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
float row_reduce = 0;
size_t reduce_start = i * row_dim_size;
for (size_t j = 0; j < row_dim_size; ++j) {
row_reduce += static_cast<float>(update[reduce_start + j]);
row_reduce += update[reduce_start + j];
}
row_reduce = row_reduce / row_dim_size;
auto tmp = static_cast<float>(exp_avg_sq_row[i]) * beta2t + row_reduce * one_minus_beta2t;
@@ -135,8 +129,7 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
col_reduce /= col_dim_size;
col_reduce = std::max(col_reduce, kEps);
for (size_t j = 0; j < col_dim_size; ++j) {
auto tmp = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
r_factor[reduce_start + j] = static_cast<T>(tmp);
r_factor[reduce_start + j] = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
}
}
};
@@ -149,13 +142,13 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
float row_reduce = 0;
size_t reduce_start = i / row_dim_size * last_row_col_size + i % row_dim_size;
for (size_t j = 0; j < col_dim_size; ++j) {
row_reduce += static_cast<float>(update[reduce_start + j * row_dim_size]);
row_reduce += update[reduce_start + j * row_dim_size];
}
row_reduce = row_reduce / col_dim_size;
auto tmp = static_cast<float>(exp_avg_sq_col[i]) * beta2t + row_reduce * one_minus_beta2t;
tmp = std::max(tmp, kEps);
exp_avg_sq_col[i] = static_cast<T>(tmp);
c_factor[i] = static_cast<T>(std::sqrt(tmp));
c_factor[i] = std::sqrt(tmp);
}
};
CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);
@@ -166,12 +159,8 @@ void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressP
size_t row_i = i % row_dim_size;
size_t col_i = i / row_dim_size % col_dim_size;
size_t slice = i / last_row_col_size;
auto left = static_cast<float>(r_factor[slice * col_dim_size + col_i]);
auto right = static_cast<float>(c_factor[slice * row_dim_size + row_i]);
auto norm = left * right;
norm = std::max(norm, kEps);
auto tmp = static_cast<float>(grad[i]) / norm;
update[i] = static_cast<T>(tmp);
auto norm = r_factor[slice * col_dim_size + col_i] * c_factor[slice * row_dim_size + row_i];
update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::max(norm, kEps);
}
};
CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
@@ -190,7 +179,7 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
auto param = reinterpret_cast<T *>(inputs[PARAM]->addr);
auto exp_avg = reinterpret_cast<T *>(inputs[EXP_AVG]->addr);
auto exp_avg_sq = reinterpret_cast<T *>(inputs[EXP_AVG_SQ]->addr);
auto update = reinterpret_cast<T *>(workspaces[UPDATE]->addr);
auto update = reinterpret_cast<float *>(workspaces[UPDATE]->addr);
auto one_minus_beta1 = 1 - beta1;
auto one_minus_beta2t = 1 - beta2t;
if (clip_threshold <= 0) {
@@ -211,23 +200,22 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
// update = grad * grad + eps[0]
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto tmp = static_cast<float>(grad[i]);
update[i] = static_cast<T>(tmp * tmp + epsilon[0]);
auto tmp = static_cast<float>(grad[i]) * global_norm_reciprocal_;
update[i] = tmp * tmp + epsilon[0];
}
};
CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);

if (need_factor_) {
FactorUpdate(update, inputs, workspaces);
FactorUpdate<T>(update, inputs, workspaces);
} else {
// no factor
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + static_cast<float>(update[i]) * one_minus_beta2t;
auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + update[i] * one_minus_beta2t;
tmp = std::max(tmp, kEps);
exp_avg_sq[i] = static_cast<T>(tmp);
tmp = static_cast<float>(grad[i]) / std::sqrt(static_cast<float>(exp_avg_sq[i]));
update[i] = static_cast<T>(tmp);
update[i] = static_cast<float>(grad[i]) * global_norm_reciprocal_ / std::sqrt(tmp);
}
};
CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
@@ -244,18 +232,16 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
auto update_coff = learning_rate / std::max(update_rms_thres, 1.0f);
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto tmp = static_cast<float>(update[i]) * update_coff;
update[i] = static_cast<T>(tmp);
update[i] = update[i] * update_coff;
if (enable_first_moment_) {
tmp = static_cast<float>(exp_avg[i]) * beta1 + static_cast<float>(update[i]) * one_minus_beta1;
exp_avg[i] = static_cast<T>(tmp);
update[i] = exp_avg[i];
update[i] = static_cast<float>(exp_avg[i]) * beta1 + update[i] * one_minus_beta1;
exp_avg[i] = static_cast<T>(update[i]);
}
if (enable_weight_decay_) {
tmp = static_cast<float>(param[i]) * weight_decay * learning_rate;
param[i] = param[i] - update[i] - static_cast<T>(tmp);
auto tmp = static_cast<float>(param[i]) * weight_decay * learning_rate;
param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i] - tmp);
} else {
param[i] = param[i] - update[i];
param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i]);
}
}
};
@@ -265,15 +251,17 @@ void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
bool FusedAdaFactorCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspaces,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kFusedAdaFactorInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kFusedAdaFactorInputNum
<< ", but got: " << inputs.size();
}
if (workspaces.size() != kFusedAdaFactorWorkSpaceNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces should be "
<< kFusedAdaFactorWorkSpaceNum << ", but got: " << workspaces.size();
if (inputs.size() == kStandardInputNum + 1) {
auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
if (global_norm < kEps) {
global_norm_reciprocal_ = 1.0f;
} else {
global_norm_reciprocal_ = 1.0f / global_norm;
}
}
CheckParam(inputs, workspaces, outputs);

CheckInputAddresses(inputs);
CheckWorkspaceAddresses(workspaces);
if (param_dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, workspaces, outputs);
} else {
@@ -282,9 +270,12 @@ bool FusedAdaFactorCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
return true;
}

void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspaces,
const std::vector<kernel::AddressPtr> &) const {
void FusedAdaFactorCPUKernel::CheckInputAddresses(const std::vector<kernel::AddressPtr> &inputs) const {
if (inputs.size() < kStandardInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be at least " << kStandardInputNum
<< ", but got: " << inputs.size();
}

if (inputs[EPSILON]->size != kSizeFloat32 << 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' should be " << (kSizeFloat32 << 1)
<< ", but got " << inputs[EPSILON]->size;
@@ -293,7 +284,6 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32
<< ", but got " << inputs[BETA1]->size;
}

if (inputs[BETA1]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32
<< ", but got " << inputs[BETA1]->size;
@@ -320,10 +310,6 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' should be " << param_size
<< ", but got " << inputs[GRAD]->size;
}
if (workspaces[UPDATE]->size != param_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update ' should be " << param_size
<< ", but got " << workspaces[0]->size;
}

if (enable_first_moment_ && inputs[EXP_AVG]->size != param_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' should be " << param_size
@@ -347,5 +333,28 @@ void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &
<< param_size / last_col_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_COL]->size;
}
}

void FusedAdaFactorCPUKernel::CheckWorkspaceAddresses(const std::vector<kernel::AddressPtr> &workspaces) const {
  // Validates that the framework handed us exactly the workspaces requested in
  // InitInputOutputSize, with the expected byte sizes. All three workspaces are
  // float32 regardless of the parameter dtype (fp16 params are accumulated in fp32).
  if (workspaces.size() != kWorkSpaceNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces should be " << kWorkSpaceNum
                      << ", but got: " << workspaces.size();
  }

  size_t update_size = elem_num_ * kSizeFloat32;

  if (workspaces[UPDATE]->size != update_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update' should be " << update_size
                      << ", but got " << workspaces[UPDATE]->size;
  }

  // r_factor holds one value per row slice, c_factor one per column slice.
  if (workspaces[R_FACTOR]->size != update_size / last_row_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'r_factor' should be "
                      << update_size / last_row_dim_size_ << ", but got " << workspaces[R_FACTOR]->size;
  }
  if (workspaces[C_FACTOR]->size != update_size / last_col_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'c_factor' should be "
                      << update_size / last_col_dim_size_ << ", but got " << workspaces[C_FACTOR]->size;
  }
}
} // namespace kernel
} // namespace mindspore

+ 43
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h View File

@@ -33,8 +33,9 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs) const;
void CheckInputAddresses(const std::vector<AddressPtr> &inputs) const;
void CheckWorkspaceAddresses(const std::vector<AddressPtr> &workspaces) const;

template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs);
@@ -43,7 +44,7 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
float CalcRMS(T *input, size_t elem_num);

template <typename T>
void FactorUpdate(T *update, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces);
void FactorUpdate(float *update, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces);

bool enable_scale_parameter_{false};
bool enable_first_moment_{false};
@@ -53,6 +54,7 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
size_t last_row_dim_size_{1};
size_t last_col_dim_size_{1};
TypeId param_dtype_{kTypeUnknown};
float global_norm_reciprocal_{1.0f};

enum InputEnum {
EPSILON,
@@ -66,7 +68,8 @@ class FusedAdaFactorCPUKernel : public CPUKernel {
EXP_AVG,
EXP_AVG_SQ_ROW,
EXP_AVG_SQ_COL,
EXP_AVG_SQ
EXP_AVG_SQ,
GLOBAL_NORM
};

enum WorkspaceEnum { UPDATE, R_FACTOR, C_FACTOR };
@@ -105,6 +108,42 @@ MS_REG_CPU_KERNEL(FusedAdaFactor,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
FusedAdaFactorCPUKernel)

MS_REG_CPU_KERNEL(FusedAdaFactorWithGlobalNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedAdaFactorCPUKernel)

MS_REG_CPU_KERNEL(FusedAdaFactorWithGlobalNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
FusedAdaFactorCPUKernel)
} // namespace kernel
} // namespace mindspore



+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -215,6 +215,7 @@ constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedAdaFactorName = "FusedAdaFactor";
constexpr auto kFusedAdaFactorWithGlobalNormName = "FusedAdaFactorWithGlobalNorm";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
constexpr auto kDeadNodeName = "DeadNode";
@@ -693,6 +694,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kFusedCastAdamWeightDecayName,
kFusedAdamName,
kFusedAdaFactorName,
kFusedAdaFactorWithGlobalNormName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,
kFusedWeightScaleApplyMomentum,


+ 5
- 2
mindspore/python/mindspore/ops/operations/__init__.py View File

@@ -35,7 +35,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@@ -44,7 +45,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
MakeRefKey,
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor)
FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay, FusedAdaFactor,
FusedAdaFactorWithGlobalNorm)

from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr, Ger,
@@ -178,6 +180,7 @@ __all__ = [
'FusedSparseLazyAdam',
'AdamNoUpdateParam',
'FusedAdaFactor',
'FusedAdaFactorWithGlobalNorm',
'Softplus',
'Softmax',
'Softsign',


+ 22
- 0
mindspore/python/mindspore/ops/operations/inner_ops.py View File

@@ -689,3 +689,25 @@ class FusedAdaFactor(PrimitiveWithInfer):
learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
exp_avg_sq_col_type, exp_avg_sq_type):
return param_type


class FusedAdaFactorWithGlobalNorm(FusedAdaFactor):
    r"""
    Divide global norm for gradient in FusedAdaFactor, and refer to super class for FusedAdaFactor details
    """

    @prim_attr_register
    def __init__(self, enable_scale_parameter=False, enable_first_moment=False, enable_weight_decay=False):
        # Same attributes as FusedAdaFactor; the only difference is the extra
        # trailing `global_norm` scalar input handled in the infer methods below.
        super(FusedAdaFactorWithGlobalNorm, self).__init__(enable_scale_parameter, enable_first_moment,
                                                           enable_weight_decay)

    def infer_shape(self, epsilon_shape, clip_threshold_shape, beta1_shape, beta2t_shape, weight_decay_shape,
                    learning_rate_shape, grad_shape, param_shape, exp_avg_shape, exp_avg_sq_row_shape,
                    exp_avg_sq_col_shape, exp_avg_sq_shape, global_norm_shape):
        # Gradient must match the parameter shape; output shape is the parameter shape.
        validator.check("grad_shape", grad_shape, "param_shape", param_shape, Rel.EQ, self.name)
        return param_shape

    def infer_dtype(self, epsilon_type, clip_threshold_type, beta1_type, beta2t_type, weight_decay_type,
                    learning_rate_type, grad_type, param_type, exp_avg_type, exp_avg_sq_row_type,
                    exp_avg_sq_col_type, exp_avg_sq_type, global_norm_type):
        # Output dtype follows the parameter dtype.
        return param_type

+ 28
- 0
tests/st/ops/cpu/test_fused_ada_factor_op.py View File

@@ -39,6 +39,17 @@ class Net(nn.Cell):
return out


class NetWithGlobalNorm(Net):
    # Same cell as Net, but the optimizer is the global-norm variant, which
    # takes one extra scalar input (global_norm) appended to the argument list.
    def __init__(self):
        super(NetWithGlobalNorm, self).__init__()
        self.opt = ops.FusedAdaFactorWithGlobalNorm()

    def construct(self, epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, global_norm):
        out = self.opt(epsilon, clip_threshold, beta1, beta2, weight_decay, lr, grad, self.param, self.exp_avg,
                       self.exp_avg_sq_row, self.exp_avg_sq_col, self.exp_avg_sq, global_norm)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -54,3 +65,20 @@ def test_adafactor():
net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient)
diff = net.param.asnumpy() - np.ones(param_shape) * 0.97
assert np.all(diff < 1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adafactor_with_global_norm():
    '''
    Feature: AdaFactor
    Description: Test FusedAdaFactorWithGlobalNorm with a scalar global norm input on CPU
    Expectation: Run success
    '''
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    net = NetWithGlobalNorm()
    gradient = Tensor(np.ones(param_shape), mstype.float32)
    # Same hyper-parameters as test_adafactor, plus global_norm=10.0 as the last input.
    net((1e-30, 1e-3), 1.0, 0.9, 0.8, 1e-2, 0.03, gradient, 10.0)
    diff = net.param.asnumpy() - np.ones(param_shape) * 0.97
    assert np.all(diff < 1e-3)

+ 52
- 23
tests/ut/cpp/kernel/cpu/fused_ada_factor_cpu_kernel_test.cc View File

@@ -57,7 +57,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
grad_.resize(elem_num_);
exp_avg_.resize(elem_num_);
exp_avg_sq_.resize(elem_num_);
update_.resize(elem_num_);
update_.resize(elem_num_, 0.0f);
for (size_t i = 0; i < elem_num_; ++i) {
auto ptr = (float16 *)param_.data();
ptr[i] = static_cast<float16>(1.0f);
@@ -67,8 +67,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)exp_avg_sq_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)update_.data();
ptr[i] = static_cast<float16>(0.0f);
}

auto r_factor_num = elem_num_ / last_row_dim_size_;
@@ -77,8 +75,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
for (size_t i = 0; i < r_factor_num; ++i) {
auto ptr = (float16 *)exp_avg_sq_row_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)r_factor_.data();
ptr[i] = static_cast<float16>(0.0f);
}

auto c_factor_num = elem_num_ / last_col_dim_size_;
@@ -87,8 +83,6 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
for (size_t i = 0; i < c_factor_num; ++i) {
auto ptr = (float16 *)exp_avg_sq_col_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)c_factor_.data();
ptr[i] = static_cast<float16>(0.0f);
}
}

@@ -99,7 +93,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
return kernel_addr;
}

void CreateAddress() {
void CreateAddress(bool enable_global_norm) {
constexpr size_t eps_num = 2;
inputs_.push_back(CreateKernelAddress(epsilon_.data(), eps_num, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&clip_threshold_, 1, kSizeFloat32));
@@ -113,36 +107,37 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
inputs_.push_back(CreateKernelAddress(exp_avg_sq_row_.data(), elem_num_ / last_row_dim_size_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_sq_col_.data(), elem_num_ / last_col_dim_size_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_sq_.data(), elem_num_, type_size_));
workspace_.push_back(CreateKernelAddress(update_.data(), elem_num_, type_size_));
workspace_.push_back(CreateKernelAddress(r_factor_.data(), elem_num_ / last_row_dim_size_, type_size_));
workspace_.push_back(CreateKernelAddress(c_factor_.data(), elem_num_ / last_col_dim_size_, type_size_));
workspace_.push_back(CreateKernelAddress(update_.data(), elem_num_, kSizeFloat32));
workspace_.push_back(CreateKernelAddress(r_factor_.data(), elem_num_ / last_row_dim_size_, kSizeFloat32));
workspace_.push_back(CreateKernelAddress(c_factor_.data(), elem_num_ / last_col_dim_size_, kSizeFloat32));
if (enable_global_norm) {
inputs_.push_back(CreateKernelAddress(&global_norm_, 1, kSizeFloat32));
}
}

void ComputeFp32() {
void ComputeFp32(bool enable_global_norm) {
ada_factor_->param_dtype_ = kNumberTypeFloat32;
type_size_ = sizeof(float);
InitDataFp32();

CreateAddress();
CreateAddress(enable_global_norm);
ada_factor_->Launch(inputs_, workspace_, outputs_);

constexpr float result = 0.97;
for (size_t i = 0; i < elem_num_; ++i) {
EXPECT_TRUE(std::fabs(param_[i] - result) < 1e-6);
EXPECT_TRUE(std::fabs(param_[i] - result_) < 1e-6);
}
}

void ComputeFp16() {
void ComputeFp16(bool enable_global_norm) {
ada_factor_->param_dtype_ = kNumberTypeFloat16;
type_size_ = sizeof(float16);
InitDataFp16();

CreateAddress();
CreateAddress(enable_global_norm);
ada_factor_->Launch(inputs_, workspace_, outputs_);
constexpr float result = 0.97;
auto ptr = (float16 *)param_.data();
for (size_t i = 0; i < elem_num_; ++i) {
EXPECT_TRUE(std::fabs(static_cast<float>(ptr[i]) - result) < 1e-3);
EXPECT_TRUE(std::fabs(static_cast<float>(ptr[i]) - result_) < 1e-3);
}
}

@@ -152,6 +147,8 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
float beta1_ = 0.9;
float beta2t_ = 0.8;
float weight_decay_ = 1e-2;
float global_norm_ = 10.0f;
float result_ = 0.97;
std::vector<float> param_;
std::vector<float> grad_;
std::vector<float> exp_avg_;
@@ -179,7 +176,7 @@ class FusedAdaFactorCpuKernelTest : public UT::Common {
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor) {
ada_factor_->need_factor_ = true;
ComputeFp32();
ComputeFp32(false);
}

/// Feature: FusedAdaFactor
@@ -187,7 +184,23 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor) {
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor) {
ada_factor_->need_factor_ = false;
ComputeFp32();
ComputeFp32(false);
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that needs factor state with fp32 data inputs and global norm
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor_global_norm) {
ada_factor_->need_factor_ = true;
ComputeFp32(true);
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that doesn't need factor state with fp32 data inputs and global norm
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor_global_norm) {
ada_factor_->need_factor_ = false;
ComputeFp32(true);
}

/// Feature: FusedAdaFactor
@@ -195,7 +208,7 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor) {
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor) {
ada_factor_->need_factor_ = true;
ComputeFp16();
ComputeFp16(false);
}

/// Feature: FusedAdaFactor
@@ -203,7 +216,23 @@ TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor) {
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_no_factor) {
ada_factor_->need_factor_ = false;
ComputeFp16();
ComputeFp16(false);
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that needs factor state with fp16 data inputs and global norm
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor_global_norm) {
ada_factor_->need_factor_ = true;
ComputeFp16(true);
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that doesn't need factor state with fp16 data inputs and global norm
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_no_factor_global_norm) {
ada_factor_->need_factor_ = false;
ComputeFp16(true);
}
} // namespace kernel
} // namespace mindspore

Loading…
Cancel
Save