Browse Source

!28197 add cpu fused adafactor

Merge pull request !28197 from kisnwang/add-cpu-adafactor
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
749917a819
6 changed files with 674 additions and 1 deletions
  1. +351
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.cc
  2. +111
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h
  3. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.cc
  4. +2
    -0
      mindspore/ccsrc/utils/utils.h
  5. +1
    -0
      tests/ut/cpp/CMakeLists.txt
  6. +209
    -0
      tests/ut/cpp/kernel/cpu/fused_ada_factor_cpu_kernel_test.cc

+ 351
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.cc View File

@@ -0,0 +1,351 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
// `static` is redundant inside an anonymous namespace (entities already have
// internal linkage), so it is dropped here.
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
// Index of the single scalar carried by scalar inputs (beta1, lr, ...).
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedAdaFactorInputNum = 12;
constexpr size_t kFusedAdaFactorWorkSpaceNum = 3;
// Minimum chunk of elements handed to one worker by CPUKernelUtils::ParallelFor.
constexpr size_t kBatchSize = 10000;
// Attribute names read from the kernel node in InitKernel.
auto constexpr kEnableScaleParameter = "enable_scale_parameter";
auto constexpr kEnableFirstMoment = "enable_first_moment";
auto constexpr kEnableWeightDecay = "enable_weight_decay";
// Offsets (counted from the end of the shape) of the last two dimensions.
constexpr size_t kLastRowIndex = 1;
constexpr size_t kLastColIndex = 2;
// Floor applied to divisors / sqrt arguments to keep them away from zero.
constexpr float kEps = 1e-30;
}  // namespace

void FusedAdaFactorCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
if (param_dtype_ == kNumberTypeFloat16) {
(void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat16);
(void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat16);
(void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat16);
} else {
(void)workspace_size_list_.emplace_back(elem_num_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_row_dim_size_ * kSizeFloat32);
(void)workspace_size_list_.emplace_back(elem_num_ / last_col_dim_size_ * kSizeFloat32);
}
}

// Reads the parameter shape/dtype and the optional boolean attributes from the
// kernel node, and decides whether the factored second-moment path is used.
void FusedAdaFactorCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  param_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, PARAM);
  auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, PARAM);
  // Use a size_t init so std::accumulate's accumulator type matches
  // std::multiplies<size_t> (the original 1LL mixed long long with size_t).
  elem_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  if (elem_num_ < 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the elem num of 'param' should not be zero.";
  }
  if (shape.size() >= kLastColIndex) {
    // Factoring the second moment is only possible with at least 2 dims.
    need_factor_ = true;
    last_row_dim_size_ = shape[shape.size() - kLastRowIndex];
    last_col_dim_size_ = shape[shape.size() - kLastColIndex];
    if (last_row_dim_size_ < 1 || last_col_dim_size_ < 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the shape of 'param' should not be zero.";
    }
  } else {
    // Explicitly set both to 1 so later divisions (workspace sizing,
    // CheckParam) never divide by zero for 0-D/1-D params.
    last_row_dim_size_ = 1;
    last_col_dim_size_ = 1;
  }

  if (AnfAlgo::HasNodeAttr(kEnableScaleParameter, kernel_node)) {
    enable_scale_parameter_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kEnableScaleParameter);
  }
  if (AnfAlgo::HasNodeAttr(kEnableFirstMoment, kernel_node)) {
    enable_first_moment_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kEnableFirstMoment);
  }
  if (AnfAlgo::HasNodeAttr(kEnableWeightDecay, kernel_node)) {
    enable_weight_decay_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kEnableWeightDecay);
  }
}

// Root-mean-square of `input`, accumulated in fp32 regardless of T so that
// float16 inputs do not lose precision during the summation.
template <typename T>
float FusedAdaFactorCPUKernel::CalcRMS(T *input, size_t elem_num) {
  if (elem_num == 0) {
    return 0.0f;
  }

  float sum_of_squares = 0;
  for (size_t idx = 0; idx < elem_num; ++idx) {
    auto value = static_cast<float>(input[idx]);
    sum_of_squares += value * value;
  }
  return std::sqrt(sum_of_squares / elem_num);
}

// Factored second-moment update (the memory-saving path of AdaFactor).
// On entry `update` holds grad^2 + eps[0] (computed in LaunchKernel); on exit it
// holds grad / (r_factor * c_factor). Updates the exp_avg_sq_row / exp_avg_sq_col
// statistics in place. All arithmetic is accumulated in fp32, results are cast
// back to T (float or float16).
template <typename T>
void FusedAdaFactorCPUKernel::FactorUpdate(T *update, const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspaces) {
  auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr);
  auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_ROW]->addr);
  auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_COL]->addr);
  auto r_factor = reinterpret_cast<T *>(workspaces[R_FACTOR]->addr);
  auto c_factor = reinterpret_cast<T *>(workspaces[C_FACTOR]->addr);
  auto one_minus_beta2t = 1 - beta2t;

  std::function<void(size_t, size_t)> task;
  // Element counts: exp_avg_sq_row drops the last dim, exp_avg_sq_col drops the
  // second-to-last dim of the param shape.
  size_t exp_avg_sq_row_elem_num = elem_num_ / last_row_dim_size_;
  size_t exp_avg_sq_col_elem_num = elem_num_ / last_col_dim_size_;
  size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
  // Local copies so the lambdas capture plain values rather than members.
  size_t row_dim_size = last_row_dim_size_;
  size_t col_dim_size = last_col_dim_size_;
  // Phase 1: exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      // Mean of `update` over the last (contiguous) dimension.
      float row_reduce = 0;
      size_t reduce_start = i * row_dim_size;
      for (size_t j = 0; j < row_dim_size; ++j) {
        row_reduce += static_cast<float>(update[reduce_start + j]);
      }
      row_reduce = row_reduce / row_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_row[i]) * beta2t + row_reduce * one_minus_beta2t;
      exp_avg_sq_row[i] = static_cast<T>(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);

  // Phase 2: r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
  // Each `i` indexes one slice of col_dim_size row statistics.
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float col_reduce = 0;
      size_t reduce_start = i * col_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        col_reduce += static_cast<float>(exp_avg_sq_row[reduce_start + j]);
      }
      col_reduce /= col_dim_size;
      // Floor the mean so the division below never divides by zero.
      col_reduce = std::max(col_reduce, kEps);
      for (size_t j = 0; j < col_dim_size; ++j) {
        auto tmp = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
        r_factor[reduce_start + j] = static_cast<T>(tmp);
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);

  // Phase 3: exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
  // c_factor = sqrt(exp_avg_sq_col);
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      float row_reduce = 0;
      // Strided walk over the second-to-last dim: stride row_dim_size within the
      // slice that contains flat column index i.
      size_t reduce_start = i / row_dim_size * last_row_col_size + i % row_dim_size;
      for (size_t j = 0; j < col_dim_size; ++j) {
        row_reduce += static_cast<float>(update[reduce_start + j * row_dim_size]);
      }
      row_reduce = row_reduce / col_dim_size;
      auto tmp = static_cast<float>(exp_avg_sq_col[i]) * beta2t + row_reduce * one_minus_beta2t;
      tmp = std::max(tmp, kEps);
      exp_avg_sq_col[i] = static_cast<T>(tmp);
      c_factor[i] = static_cast<T>(std::sqrt(tmp));
    }
  };
  CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);

  // Phase 4: update = grad / (r_factor * c_factor);
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      // Decompose the flat index into (slice, col, row) coordinates.
      size_t row_i = i % row_dim_size;
      size_t col_i = i / row_dim_size % col_dim_size;
      size_t slice = i / last_row_col_size;
      auto left = static_cast<float>(r_factor[slice * col_dim_size + col_i]);
      auto right = static_cast<float>(c_factor[slice * row_dim_size + row_i]);
      auto norm = left * right;
      norm = std::max(norm, kEps);
      auto tmp = static_cast<float>(grad[i]) / norm;
      update[i] = static_cast<T>(tmp);
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}

// Runs one AdaFactor step for param dtype T (float or float16). The scalar
// hyper-parameters (first six inputs) are always fp32. `update` is a scratch
// workspace the same size as the param; param/exp_avg/exp_avg_sq(_row/_col)
// are updated in place.
template <typename T>
void FusedAdaFactorCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspaces, const std::vector<AddressPtr> &) {
  // epsilon is a 2-element array: [0] regularizes grad^2, [1] floors the param RMS.
  auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr);
  auto clip_threshold = reinterpret_cast<float *>(inputs[CLIP_THRESHOLD]->addr)[kScalarIndex];
  auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
  auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex];
  auto weight_decay = reinterpret_cast<float *>(inputs[WEIGHT_DECAY]->addr)[kScalarIndex];
  auto learning_rate = reinterpret_cast<float *>(inputs[LEARNING_RATE]->addr)[kScalarIndex];
  auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr);
  auto param = reinterpret_cast<T *>(inputs[PARAM]->addr);
  auto exp_avg = reinterpret_cast<T *>(inputs[EXP_AVG]->addr);
  auto exp_avg_sq = reinterpret_cast<T *>(inputs[EXP_AVG_SQ]->addr);
  auto update = reinterpret_cast<T *>(workspaces[UPDATE]->addr);
  auto one_minus_beta1 = 1 - beta1;
  auto one_minus_beta2t = 1 - beta2t;
  // Validate the scalar hyper-parameters before any buffer is modified.
  if (clip_threshold <= 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', clip threshold " << clip_threshold << " is invalid. ";
  }
  // betas must lie in [0, 1].
  if (beta1 < 0 || one_minus_beta1 < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta1 " << beta1 << " is invalid. ";
  }
  if (beta2t < 0 || one_minus_beta2t < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', beta2t " << beta2t << " is invalid. ";
  }
  if (epsilon[0] < 0 || epsilon[1] < 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', epsilon (" << epsilon[0] << "," << epsilon[1]
                      << ") is invalid. ";
  }

  std::function<void(size_t, size_t)> task;
  // Step 1: update = grad * grad + eps[0]
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      auto tmp = static_cast<float>(grad[i]);
      update[i] = static_cast<T>(tmp * tmp + epsilon[0]);
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);

  // Step 2: second-moment update, factored or full.
  if (need_factor_) {
    // Rewrites `update` to grad / (r_factor * c_factor).
    FactorUpdate(update, inputs, workspaces);
  } else {
    // No factoring: keep the full exp_avg_sq and normalize grad by its sqrt.
    task = [&](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + static_cast<float>(update[i]) * one_minus_beta2t;
        // Floor before sqrt so the division below never divides by zero.
        tmp = std::max(tmp, kEps);
        exp_avg_sq[i] = static_cast<T>(tmp);
        tmp = static_cast<float>(grad[i]) / std::sqrt(static_cast<float>(exp_avg_sq[i]));
        update[i] = static_cast<T>(tmp);
      }
    };
    CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
  }

  // Step 3: scale learning rate with rms of param (relative step size),
  // floored by eps[1].
  if (enable_scale_parameter_) {
    auto rms = CalcRMS(param, elem_num_);
    learning_rate = learning_rate * std::max(epsilon[1], rms);
  }

  // Step 4: clip the update by its RMS relative to clip_threshold, then apply
  // first moment (if enabled), weight decay (if enabled), and the param update.
  auto update_rms_thres = CalcRMS(update, elem_num_) / clip_threshold;
  auto update_coff = learning_rate / std::max(update_rms_thres, 1.0f);
  task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      auto tmp = static_cast<float>(update[i]) * update_coff;
      update[i] = static_cast<T>(tmp);
      if (enable_first_moment_) {
        // exp_avg is an EMA of the (already lr-scaled) update.
        tmp = static_cast<float>(exp_avg[i]) * beta1 + static_cast<float>(update[i]) * one_minus_beta1;
        exp_avg[i] = static_cast<T>(tmp);
        update[i] = exp_avg[i];
      }
      if (enable_weight_decay_) {
        // Decoupled weight decay: shrink param by weight_decay * lr.
        tmp = static_cast<float>(param[i]) * weight_decay * learning_rate;
        param[i] = param[i] - update[i] - static_cast<T>(tmp);
      } else {
        param[i] = param[i] - update[i];
      }
    }
  };
  CPUKernelUtils::ParallelFor(task, elem_num_, kBatchSize);
}

// Entry point called by the runtime: validates counts and address sizes, then
// dispatches to the typed implementation based on the param dtype.
bool FusedAdaFactorCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &workspaces,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  // Reject malformed launches before any address is dereferenced.
  const bool inputs_ok = (inputs.size() == kFusedAdaFactorInputNum);
  if (!inputs_ok) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kFusedAdaFactorInputNum
                      << ", but got: " << inputs.size();
  }
  const bool workspaces_ok = (workspaces.size() == kFusedAdaFactorWorkSpaceNum);
  if (!workspaces_ok) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of workspaces should be "
                      << kFusedAdaFactorWorkSpaceNum << ", but got: " << workspaces.size();
  }
  CheckParam(inputs, workspaces, outputs);

  // Dispatch on the dtype determined in InitKernel.
  if (param_dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspaces, outputs);
  } else {
    LaunchKernel<float>(inputs, workspaces, outputs);
  }
  return true;
}

// Verifies that every input/workspace address has the expected byte size.
// Scalar hyper-parameters are fp32; tensor buffers match the param dtype.
// Throws (MS_LOG(EXCEPTION)) on the first mismatch.
void FusedAdaFactorCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> &workspaces,
                                         const std::vector<kernel::AddressPtr> &) const {
  // 'epsilon' carries two floats.
  if (inputs[EPSILON]->size != kSizeFloat32 << 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' should be " << (kSizeFloat32 << 1)
                      << ", but got " << inputs[EPSILON]->size;
  }
  // Fixed: this branch previously reported 'beta1' and BETA1's size by mistake.
  if (inputs[CLIP_THRESHOLD]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'clip_threshold' should be "
                      << kSizeFloat32 << ", but got " << inputs[CLIP_THRESHOLD]->size;
  }

  if (inputs[BETA1]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32
                      << ", but got " << inputs[BETA1]->size;
  }
  if (inputs[BETA2T]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta2t' should be " << kSizeFloat32
                      << ", but got " << inputs[BETA2T]->size;
  }
  if (inputs[WEIGHT_DECAY]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'weight_decay' should be " << kSizeFloat32
                      << ", but got " << inputs[WEIGHT_DECAY]->size;
  }
  if (inputs[LEARNING_RATE]->size != kSizeFloat32) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'lr' should be " << kSizeFloat32
                      << ", but got " << inputs[LEARNING_RATE]->size;
  }

  // All tensor buffers below use the param dtype.
  size_t param_size = param_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem_num_ * kSizeFloat32;
  if (inputs[PARAM]->size != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'param' should be " << param_size
                      << ", but got " << inputs[PARAM]->size;
  }
  if (inputs[GRAD]->size != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' should be " << param_size
                      << ", but got " << inputs[GRAD]->size;
  }
  // Fixed: message said 'update ' (trailing space) and read workspaces[0];
  // use the named index for clarity.
  if (workspaces[UPDATE]->size != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update' should be " << param_size
                      << ", but got " << workspaces[UPDATE]->size;
  }

  // exp_avg is only read/written when the first moment is enabled.
  if (enable_first_moment_ && inputs[EXP_AVG]->size != param_size) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' should be " << param_size
                      << ", but got " << inputs[EXP_AVG]->size;
  }

  if (!need_factor_) {
    // Non-factored path uses the full exp_avg_sq only.
    if (inputs[EXP_AVG_SQ]->size != param_size) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq' should be " << param_size
                        << ", but got " << inputs[EXP_AVG_SQ]->size;
    }
    return;
  }

  // Factored path: row/col statistics drop the last / second-to-last dim.
  if (inputs[EXP_AVG_SQ_ROW]->size != param_size / last_row_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_row' should be "
                      << param_size / last_row_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_ROW]->size;
  }
  if (inputs[EXP_AVG_SQ_COL]->size != param_size / last_col_dim_size_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_col' should be "
                      << param_size / last_col_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_COL]->size;
  }
}
} // namespace kernel
} // namespace mindspore

+ 111
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h View File

@@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel implementing a fused AdaFactor optimizer step.
// Supports float32 and float16 params; for params with >= 2 dims the second
// moment is kept in factored (row/col) form to save memory.
class FusedAdaFactorCPUKernel : public CPUKernel {
 public:
  FusedAdaFactorCPUKernel() = default;
  ~FusedAdaFactorCPUKernel() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  void InitInputOutputSize(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &outputs) override;

 private:
  // Validates the byte size of every input/workspace address; throws on mismatch.
  void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
                  const std::vector<AddressPtr> &outputs) const;
  // Typed implementation of the optimizer step (T = float or float16).
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
                    const std::vector<AddressPtr> &outputs);

  // Root-mean-square of `input`, accumulated in fp32.
  template <typename T>
  float CalcRMS(T *input, size_t elem_num);

  // Factored second-moment update; only called when need_factor_ is true.
  template <typename T>
  void FactorUpdate(T *update, const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces);

  bool enable_scale_parameter_{false};
  bool enable_first_moment_{false};
  bool enable_weight_decay_{false};
  bool need_factor_{false};
  size_t elem_num_{0};
  // Default to 1 (not 0) so workspace-size computations never divide by zero
  // when the param has fewer than 2 dims and InitKernel leaves these untouched.
  size_t last_row_dim_size_{1};
  size_t last_col_dim_size_{1};
  TypeId param_dtype_{kTypeUnknown};

  // Fixed ordering of the kernel's 12 inputs.
  enum InputEnum {
    EPSILON,
    CLIP_THRESHOLD,
    BETA1,
    BETA2T,
    WEIGHT_DECAY,
    LEARNING_RATE,
    GRAD,
    PARAM,
    EXP_AVG,
    EXP_AVG_SQ_ROW,
    EXP_AVG_SQ_COL,
    EXP_AVG_SQ
  };

  // Fixed ordering of the three workspaces allocated in InitInputOutputSize.
  enum WorkspaceEnum { UPDATE, R_FACTOR, C_FACTOR };
};

// Registration for the all-fp32 variant: the 6 scalar hyper-parameters and the
// 6 tensor inputs (grad/param/exp_avg/exp_avg_sq_row/exp_avg_sq_col/exp_avg_sq)
// are all float32.
MS_REG_CPU_KERNEL(FusedAdaFactor,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  FusedAdaFactorCPUKernel)

// Registration for the fp16 variant: the 6 scalar hyper-parameters stay
// float32, while the 6 tensor inputs and the output are float16.
MS_REG_CPU_KERNEL(FusedAdaFactor,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  FusedAdaFactorCPUKernel)
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_ADA_FACTOR_CPU_KERNEL_H_

+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.cc View File

@@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.h"
#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "nnacl/fp32/adam_fp32.h"
#include "utils/ms_utils.h"


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -215,6 +215,7 @@ constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kAdamWeightDecayName = "AdamWeightDecay";
constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedAdaFactorName = "FusedAdaFactor";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
constexpr auto kDeadNodeName = "DeadNode";
@@ -689,6 +690,7 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kAdamWeightDecayName,
kFusedCastAdamWeightDecayName,
kFusedAdamName,
kFusedAdaFactorName,
kFusedSparseAdamName,
kFusedMulApplyMomentumOpName,
kFusedWeightScaleApplyMomentum,


+ 1
- 0
tests/ut/cpp/CMakeLists.txt View File

@@ -151,6 +151,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/unique_with_pad_cpu_kernel.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/akg/*.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/rts/*.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/hccl/*.cc"


+ 209
- 0
tests/ut/cpp/kernel/cpu/fused_ada_factor_cpu_kernel_test.cc View File

@@ -0,0 +1,209 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <vector>
#include "common/common_test.h"
#define private public
#define protected public
#include "backend/kernel_compiler/cpu/fused_ada_factor_cpu_kernel.h"
#undef private
#undef protected

namespace mindspore {
namespace kernel {
static constexpr size_t kSizeFloat32 = sizeof(float);
class FusedAdaFactorCpuKernelTest : public UT::Common {
public:
FusedAdaFactorCpuKernelTest() : ada_factor_(std::make_shared<FusedAdaFactorCPUKernel>()) {}

void SetUp() override {
ada_factor_->elem_num_ = elem_num_;
ada_factor_->kernel_name_ = "AdaFactorTest";
ada_factor_->last_row_dim_size_ = last_row_dim_size_;
ada_factor_->last_col_dim_size_ = last_col_dim_size_;
}

void InitDataFp32() {
exp_avg_.resize(elem_num_, 0.0f);
exp_avg_sq_.resize(elem_num_, 0.0f);
update_.resize(elem_num_, 0.0f);
param_.resize(elem_num_, 1.0f);
grad_.resize(elem_num_, 1.0f);

auto r_factor_num = elem_num_ / last_row_dim_size_;
exp_avg_sq_row_.resize(r_factor_num, 0.0f);
r_factor_.resize(r_factor_num, 0.0f);

auto c_factor_num = elem_num_ / last_col_dim_size_;
exp_avg_sq_col_.resize(c_factor_num, 0.0f);
c_factor_.resize(c_factor_num, 0.0f);
}

void InitDataFp16() {
param_.resize(elem_num_);
grad_.resize(elem_num_);
exp_avg_.resize(elem_num_);
exp_avg_sq_.resize(elem_num_);
update_.resize(elem_num_);
for (size_t i = 0; i < elem_num_; ++i) {
auto ptr = (float16 *)param_.data();
ptr[i] = static_cast<float16>(1.0f);
ptr = (float16 *)grad_.data();
ptr[i] = static_cast<float16>(1.0f);
ptr = (float16 *)exp_avg_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)exp_avg_sq_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)update_.data();
ptr[i] = static_cast<float16>(0.0f);
}

auto r_factor_num = elem_num_ / last_row_dim_size_;
exp_avg_sq_row_.resize(r_factor_num, 0.0f);
r_factor_.resize(r_factor_num, 0.0f);
for (size_t i = 0; i < r_factor_num; ++i) {
auto ptr = (float16 *)exp_avg_sq_row_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)r_factor_.data();
ptr[i] = static_cast<float16>(0.0f);
}

auto c_factor_num = elem_num_ / last_col_dim_size_;
exp_avg_sq_col_.resize(c_factor_num, 0.0f);
c_factor_.resize(c_factor_num, 0.0f);
for (size_t i = 0; i < c_factor_num; ++i) {
auto ptr = (float16 *)exp_avg_sq_col_.data();
ptr[i] = static_cast<float16>(0.0f);
ptr = (float16 *)c_factor_.data();
ptr[i] = static_cast<float16>(0.0f);
}
}

AddressPtr CreateKernelAddress(void *addr, size_t elem_num, size_t type_size) {
auto kernel_addr = std::make_shared<Address>();
kernel_addr->addr = addr;
kernel_addr->size = elem_num * type_size;
return kernel_addr;
}

void CreateAddress() {
constexpr size_t eps_num = 2;
inputs_.push_back(CreateKernelAddress(epsilon_.data(), eps_num, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&clip_threshold_, 1, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&beta1_, 1, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&beta2t_, 1, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&weight_decay_, 1, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(&lr_, 1, kSizeFloat32));
inputs_.push_back(CreateKernelAddress(grad_.data(), elem_num_, type_size_));
inputs_.push_back(CreateKernelAddress(param_.data(), elem_num_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_.data(), elem_num_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_sq_row_.data(), elem_num_ / last_row_dim_size_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_sq_col_.data(), elem_num_ / last_col_dim_size_, type_size_));
inputs_.push_back(CreateKernelAddress(exp_avg_sq_.data(), elem_num_, type_size_));
workspace_.push_back(CreateKernelAddress(update_.data(), elem_num_, type_size_));
workspace_.push_back(CreateKernelAddress(r_factor_.data(), elem_num_ / last_row_dim_size_, type_size_));
workspace_.push_back(CreateKernelAddress(c_factor_.data(), elem_num_ / last_col_dim_size_, type_size_));
}

void ComputeFp32() {
ada_factor_->param_dtype_ = kNumberTypeFloat32;
type_size_ = sizeof(float);
InitDataFp32();

CreateAddress();
ada_factor_->Launch(inputs_, workspace_, outputs_);

constexpr float result = 0.97;
for (size_t i = 0; i < elem_num_; ++i) {
EXPECT_TRUE(std::fabs(param_[i] - result) < 1e-6);
}
}

void ComputeFp16() {
ada_factor_->param_dtype_ = kNumberTypeFloat16;
type_size_ = sizeof(float16);
InitDataFp16();

CreateAddress();
ada_factor_->Launch(inputs_, workspace_, outputs_);
constexpr float result = 0.97;
auto ptr = (float16 *)param_.data();
for (size_t i = 0; i < elem_num_; ++i) {
EXPECT_TRUE(std::fabs(static_cast<float>(ptr[i]) - result) < 1e-3);
}
}

std::vector<float> epsilon_{1e-30, 1e-3};
float clip_threshold_ = 1.0;
float lr_ = 0.03;
float beta1_ = 0.9;
float beta2t_ = 0.8;
float weight_decay_ = 1e-2;
std::vector<float> param_;
std::vector<float> grad_;
std::vector<float> exp_avg_;
std::vector<float> exp_avg_sq_row_;
std::vector<float> exp_avg_sq_col_;
std::vector<float> exp_avg_sq_;

std::vector<float> update_;
std::vector<float> r_factor_;
std::vector<float> c_factor_;

std::vector<AddressPtr> inputs_;
std::vector<AddressPtr> workspace_;
std::vector<AddressPtr> outputs_;
std::shared_ptr<FusedAdaFactorCPUKernel> ada_factor_;

size_t last_row_dim_size_ = 4;
size_t last_col_dim_size_ = 6;
size_t elem_num_ = 2 * 6 * 4;
size_t type_size_ = 4;
};

// The fixture bypasses InitKernel, so need_factor_ is forced directly here to
// select the factored vs. non-factored second-moment path.

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that needs factor state with fp32 data inputs
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_factor) {
  ada_factor_->need_factor_ = true;
  ComputeFp32();
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that doesn't need factor state with fp32 data inputs
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp32_no_factor) {
  ada_factor_->need_factor_ = false;
  ComputeFp32();
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that needs factor state with fp16 data inputs
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_factor) {
  ada_factor_->need_factor_ = true;
  ComputeFp16();
}

/// Feature: FusedAdaFactor
/// Description: Run FusedAdaFactor that doesn't need factor state with fp16 data inputs
/// Expectation: pass
TEST_F(FusedAdaFactorCpuKernelTest, compute_fp16_no_factor) {
  ada_factor_->need_factor_ = false;
  ComputeFp16();
}
} // namespace kernel
} // namespace mindspore

Loading…
Cancel
Save