@@ -21,18 +21,34 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-static constexpr size_t kSizeFloat32 = sizeof(float);
-static constexpr size_t kSizeFloat16 = sizeof(float16);
-static constexpr size_t kScalarIndex = 0;
-static constexpr size_t kStandardInputNum = 12;
-static constexpr size_t kWorkSpaceNum = 3;
-static constexpr size_t kBatchSize = 1000;
-static auto constexpr kEnableScaleParameter = "enable_scale_parameter";
-static auto constexpr kEnableFirstMoment = "enable_first_moment";
-static auto constexpr kEnableWeightDecay = "enable_weight_decay";
-static constexpr size_t kLastRowIndex = 1;
-static constexpr size_t kLastColIndex = 2;
-static constexpr float kEps = 1e-30;
+constexpr size_t kSizeFloat32 = sizeof(float);
+constexpr size_t kSizeFloat16 = sizeof(float16);
+constexpr size_t kScalarIndex = 0;
+constexpr size_t kStandardInputNum = 12;
+constexpr size_t kWorkSpaceNum = 3;
+constexpr size_t kBatchSize = 1000;
+auto constexpr kEnableScaleParameter = "enable_scale_parameter";
+auto constexpr kEnableFirstMoment = "enable_first_moment";
+auto constexpr kEnableWeightDecay = "enable_weight_decay";
+constexpr size_t kLastRowIndex = 1;
+constexpr size_t kLastColIndex = 2;
+constexpr float kEps = 1e-30;
+constexpr size_t kEpsIndex = 0;
+constexpr size_t kClipThresholdIndex = 1;
+constexpr size_t kBeta1Index = 2;
+constexpr size_t kBeta2tIndex = 3;
+constexpr size_t kWeightDecayIndex = 4;
+constexpr size_t kLearningRateIndex = 5;
+constexpr size_t kGradIndex = 6;
+constexpr size_t kParamIndex = 7;
+constexpr size_t kExpAvgIndex = 8;
+constexpr size_t kExpAvgSQRowIndex = 9;
+constexpr size_t kExpAvgSQColIndex = 10;
+constexpr size_t kExpAvgSQIndex = 11;
+constexpr size_t kGlobalNormIndex = 12;
+constexpr size_t kWorkSpaceUpdateIndex = 0;
+constexpr size_t kWorkSpaceRFactorIndex = 1;
+constexpr size_t kWorkSpaceCFactorIndex = 2;
 }  // namespace

 void FusedAdaFactorCpuKernelMod::InitInputOutputSize(const CNodePtr &kernel_node) {
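Note: inside an unnamed namespace every name already has internal linkage, and a namespace-scope constexpr variable is implicitly internal in any case, so the dropped `static` qualifiers change nothing observable. A minimal sketch of the equivalence:

    namespace {
    static constexpr size_t kA = 0;  // internal linkage; 'static' is redundant here
    constexpr size_t kB = 1;         // identical linkage without it
    }  // namespace

The new kEpsIndex through kGlobalNormIndex constants (values 0 through 12) spell out the kernel's input ordering — epsilon, clip_threshold, beta1, beta2t, weight_decay, lr, grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq, global_norm — which the replaced enumerators such as PARAM and BETA2T encoded implicitly.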
@@ -45,9 +61,9 @@ void FusedAdaFactorCpuKernelMod::InitInputOutputSize(const CNodePtr &kernel_node
 void FusedAdaFactorCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-  param_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, PARAM);
-  auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, PARAM);
-  elem_num_ = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<size_t>());
+  param_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kParamIndex);
+  auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, kParamIndex);
+  elem_num_ = std::accumulate(shape.begin(), shape.end(), 1UL, std::multiplies<size_t>());
   if (elem_num_ < 1) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the elem num of 'param' should not be zero.";
   }
@@ -98,19 +114,19 @@ float FusedAdaFactorCpuKernelMod::CalcRMS(T *input, size_t elem_num) {
   }
   (void)common::ThreadPool::GetInstance().SyncRun(tasks);
   auto rms = std::accumulate(block_sum.begin(), block_sum.end(), 0.0f);
-  rms /= elem_num;
+  rms = rms / elem_num;
   return std::sqrt(rms);
 }

 template <typename T>
 void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<AddressPtr> &inputs,
                                               const std::vector<AddressPtr> &workspaces) {
-  auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex];
-  auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr);
-  auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_ROW]->addr);
-  auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[EXP_AVG_SQ_COL]->addr);
-  auto r_factor = reinterpret_cast<float *>(workspaces[R_FACTOR]->addr);
-  auto c_factor = reinterpret_cast<float *>(workspaces[C_FACTOR]->addr);
+  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->addr)[kScalarIndex];
+  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->addr);
+  auto exp_avg_sq_row = reinterpret_cast<T *>(inputs[kExpAvgSQRowIndex]->addr);
+  auto exp_avg_sq_col = reinterpret_cast<T *>(inputs[kExpAvgSQColIndex]->addr);
+  auto r_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceRFactorIndex]->addr);
+  auto c_factor = reinterpret_cast<float *>(workspaces[kWorkSpaceCFactorIndex]->addr);
   auto one_minus_beta2t = 1 - beta2t;

   std::function<void(size_t, size_t)> task;
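Note: CalcRMS folds the per-thread partials in block_sum into a single mean and returns its square root. Assuming each block_sum entry accumulates squared elements of the buffer (the task bodies sit above this hunk), the returned value is

    \mathrm{RMS}(u) = \sqrt{ \tfrac{1}{n} \sum_{i=1}^{n} u_i^2 }

which LaunchKernel later measures against clip_threshold.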
@@ -119,7 +135,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
   size_t last_row_col_size = last_row_dim_size_ * last_col_dim_size_;
   size_t row_dim_size = last_row_dim_size_;
   size_t col_dim_size = last_col_dim_size_;
-  // step 1: exp_avg_sq_row = exp_avg_sq_row * beta2t + reduce_mean(update, -1) * one_minus_beta2t;
+  // calc exp_avg_sq_row
   task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       float row_reduce = 0;
@@ -134,7 +150,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
   };
   CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num, kBatchSize);

-  // step 2: r_factor = sqrt(exp_avg_sq_row / reduce_mean(exp_avg_sq_row, -1))
+  // calc r_factor
   task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       float col_reduce = 0;
@@ -142,7 +158,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
       for (size_t j = 0; j < col_dim_size; ++j) {
         col_reduce += static_cast<float>(exp_avg_sq_row[reduce_start + j]);
       }
-      col_reduce /= col_dim_size;
+      col_reduce = col_reduce / col_dim_size;
       col_reduce = std::max(col_reduce, kEps);
       for (size_t j = 0; j < col_dim_size; ++j) {
         r_factor[reduce_start + j] = std::sqrt(static_cast<float>(exp_avg_sq_row[reduce_start + j]) / col_reduce);
@@ -151,8 +167,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
   };
   CPUKernelUtils::ParallelFor(task, exp_avg_sq_row_elem_num / col_dim_size, kBatchSize);

-  // step 3: exp_avg_sq_col = exp_avg_sq_col * beta2t + reduce_mean(update, -2) * one_minus_beta2t;
-  // step 4: c_factor = sqrt(exp_avg_sq_col);
+  // calc exp_avg_sq_col and c_factor
   task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       float row_reduce = 0;
@@ -169,7 +184,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
   };
   CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);

-  // step 5: update = grad / (r_factor * c_factor);
+  // calc update
   task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       size_t row_i = i % row_dim_size;
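For orientation, the five steps of FactorUpdate realize Adafactor's factored second moment: a row statistic and a column statistic stand in for the full per-element accumulator, and the per-element scale is rebuilt from their outer product. In the notation of the step 1-5 comments (U is the buffer the comments call "update", G the gradient):

    R_t = \beta_{2t} R_{t-1} + (1 - \beta_{2t}) \, \mathrm{mean}_{-1}(U)
    C_t = \beta_{2t} C_{t-1} + (1 - \beta_{2t}) \, \mathrm{mean}_{-2}(U)
    r = \sqrt{ R_t / \max(\mathrm{mean}_{-1}(R_t), \mathrm{kEps}) }, \qquad c = \sqrt{ C_t }
    U' = G / (r \otimes c)

The max against kEps mirrors `col_reduce = std::max(col_reduce, kEps);` above and keeps the division well defined.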
@@ -186,17 +201,17 @@ template <typename T>
 void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                               const std::vector<AddressPtr> &workspaces,
                                               const std::vector<AddressPtr> &) {
-  auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr);
-  auto clip_threshold = reinterpret_cast<float *>(inputs[CLIP_THRESHOLD]->addr)[kScalarIndex];
-  auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
-  auto beta2t = reinterpret_cast<float *>(inputs[BETA2T]->addr)[kScalarIndex];
-  auto weight_decay = reinterpret_cast<float *>(inputs[WEIGHT_DECAY]->addr)[kScalarIndex];
-  auto learning_rate = reinterpret_cast<float *>(inputs[LEARNING_RATE]->addr)[kScalarIndex];
-  auto grad = reinterpret_cast<T *>(inputs[GRAD]->addr);
-  auto param = reinterpret_cast<T *>(inputs[PARAM]->addr);
-  auto exp_avg = reinterpret_cast<T *>(inputs[EXP_AVG]->addr);
-  auto exp_avg_sq = reinterpret_cast<T *>(inputs[EXP_AVG_SQ]->addr);
-  auto update = reinterpret_cast<float *>(workspaces[UPDATE]->addr);
+  auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr);
+  auto clip_threshold = reinterpret_cast<float *>(inputs[kClipThresholdIndex]->addr)[kScalarIndex];
+  auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->addr)[kScalarIndex];
+  auto beta2t = reinterpret_cast<float *>(inputs[kBeta2tIndex]->addr)[kScalarIndex];
+  auto weight_decay = reinterpret_cast<float *>(inputs[kWeightDecayIndex]->addr)[kScalarIndex];
+  auto learning_rate = reinterpret_cast<float *>(inputs[kLearningRateIndex]->addr)[kScalarIndex];
+  auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->addr);
+  auto param = reinterpret_cast<T *>(inputs[kParamIndex]->addr);
+  auto exp_avg = reinterpret_cast<T *>(inputs[kExpAvgIndex]->addr);
+  auto exp_avg_sq = reinterpret_cast<T *>(inputs[kExpAvgSQIndex]->addr);
+  auto update = reinterpret_cast<float *>(workspaces[kWorkSpaceUpdateIndex]->addr);
   auto one_minus_beta1 = 1 - beta1;
   auto one_minus_beta2t = 1 - beta2t;
   if (clip_threshold <= 0) {
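Note: unlike the other scalars, epsilon stays a float pointer rather than being read at kScalarIndex, and CheckInputAddresses below sizes its buffer at two floats (kSizeFloat32 << 1). That is consistent with Adafactor carrying a pair of regularizers, though how the two slots are consumed lies outside this hunk, so treat the mapping as an assumption:

    // Assumed mapping, for illustration only:
    float eps1 = epsilon[0];  // second-moment regularizer
    float eps2 = epsilon[1];  // floor for the parameter-scale RMS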
@@ -245,8 +260,9 @@ void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
   }

   // update param
-  auto update_rms_thres = CalcRMS(update, elem_num_) / clip_threshold;
-  auto update_coff = learning_rate / std::max(update_rms_thres, 1.0f);
+  auto update_rms = CalcRMS(update, elem_num_);
+  auto update_rms_threshold = update_rms / clip_threshold;
+  auto update_coff = learning_rate / std::max(update_rms_threshold, 1.0f);
   task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
       update[i] = update[i] * update_coff;
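For reference, this block is Adafactor's update clipping: with d = clip_threshold the effective step scale is

    \mathrm{update\_coff} = \mathrm{lr} / \max(\mathrm{RMS}(U) / d, \, 1)

so an update whose RMS already sits below d passes through at the plain learning rate, and a larger one is scaled down proportionally. With hypothetical numbers learning_rate = 0.01, clip_threshold = 1.0 and RMS(U) = 4.0, update_coff = 0.01 / 4.0 = 0.0025.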
@@ -255,8 +271,8 @@ void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
         exp_avg[i] = static_cast<T>(update[i]);
       }
       if (enable_weight_decay_) {
-        auto tmp = static_cast<float>(param[i]) * weight_decay * learning_rate;
-        param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i] - tmp);
+        auto tmp = update[i] + static_cast<float>(param[i]) * weight_decay * learning_rate;
+        param[i] = static_cast<T>(static_cast<float>(param[i]) - tmp);
       } else {
         param[i] = static_cast<T>(static_cast<float>(param[i]) - update[i]);
       }
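Note: with weight decay enabled the loop computes \theta_i \leftarrow \theta_i - (u_i + \eta \lambda \theta_i), i.e. decoupled weight decay scaled by the learning rate; folding everything into one float temporary performs a single cast back to T, which matters when T is float16. A quick numeric check with hypothetical values param = 1.0, update = 0.1, weight_decay = 0.01, learning_rate = 0.1: tmp = 0.1 + 1.0 * 0.01 * 0.1 = 0.101 and the parameter becomes 1.0 - 0.101 = 0.899.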
@@ -269,7 +285,7 @@ bool FusedAdaFactorCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &i
                                         const std::vector<kernel::AddressPtr> &workspaces,
                                         const std::vector<kernel::AddressPtr> &outputs) {
   if (inputs.size() == kStandardInputNum + 1) {
-    auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
+    auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
     if (global_norm < kEps) {
       global_norm_reciprocal_ = 1.0f;
     } else {
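Note: when a 13th input (kStandardInputNum + 1) is supplied it carries the gradients' global norm, whose reciprocal the kernel stores for later scaling; the kEps guard avoids dividing by a vanishing norm. A condensed sketch of this branch, assuming the else arm (outside the context shown) simply inverts the norm:

    // Hypothetical condensation of the if/else above:
    global_norm_reciprocal_ = (global_norm < kEps) ? 1.0f : 1.0f / global_norm;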
@@ -293,61 +309,61 @@ void FusedAdaFactorCpuKernelMod::CheckInputAddresses(const std::vector<kernel::A
                       << ", but got: " << inputs.size();
   }

-  if (inputs[EPSILON]->size != kSizeFloat32 << 1) {
+  if (inputs[kEpsIndex]->size != kSizeFloat32 << 1) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'epsilon' should be " << (kSizeFloat32 << 1)
-                      << ", but got " << inputs[EPSILON]->size;
+                      << ", but got " << inputs[kEpsIndex]->size;
   }
-  if (inputs[CLIP_THRESHOLD]->size != kSizeFloat32) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32
-                      << ", but got " << inputs[BETA1]->size;
+  if (inputs[kClipThresholdIndex]->size != kSizeFloat32) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'clip_threshold' should be " << kSizeFloat32
+                      << ", but got " << inputs[kClipThresholdIndex]->size;
   }
-  if (inputs[BETA1]->size != kSizeFloat32) {
+  if (inputs[kBeta1Index]->size != kSizeFloat32) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta1' should be " << kSizeFloat32
-                      << ", but got " << inputs[BETA1]->size;
+                      << ", but got " << inputs[kBeta1Index]->size;
   }
-  if (inputs[BETA2T]->size != kSizeFloat32) {
+  if (inputs[kBeta2tIndex]->size != kSizeFloat32) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'beta2t' should be " << kSizeFloat32
-                      << ", but got " << inputs[BETA2T]->size;
+                      << ", but got " << inputs[kBeta2tIndex]->size;
   }
-  if (inputs[WEIGHT_DECAY]->size != kSizeFloat32) {
+  if (inputs[kWeightDecayIndex]->size != kSizeFloat32) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'weight_decay' should be " << kSizeFloat32
-                      << ", but got " << inputs[WEIGHT_DECAY]->size;
+                      << ", but got " << inputs[kWeightDecayIndex]->size;
   }
-  if (inputs[LEARNING_RATE]->size != kSizeFloat32) {
+  if (inputs[kLearningRateIndex]->size != kSizeFloat32) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'lr' should be " << kSizeFloat32
-                      << ", but got " << inputs[LEARNING_RATE]->size;
+                      << ", but got " << inputs[kLearningRateIndex]->size;
   }

   size_t param_size = param_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem_num_ * kSizeFloat32;
-  if (inputs[PARAM]->size != param_size) {
+  if (inputs[kParamIndex]->size != param_size) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'param' should be " << param_size
-                      << ", but got " << inputs[PARAM]->size;
+                      << ", but got " << inputs[kParamIndex]->size;
   }
-  if (inputs[GRAD]->size != param_size) {
+  if (inputs[kGradIndex]->size != param_size) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' should be " << param_size
-                      << ", but got " << inputs[GRAD]->size;
+                      << ", but got " << inputs[kGradIndex]->size;
   }

-  if (enable_first_moment_ && inputs[EXP_AVG]->size != param_size) {
+  if (enable_first_moment_ && inputs[kExpAvgIndex]->size != param_size) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg' should be " << param_size
-                      << ", but got " << inputs[EXP_AVG]->size;
+                      << ", but got " << inputs[kExpAvgIndex]->size;
   }

   if (!need_factor_) {
-    if (inputs[EXP_AVG_SQ]->size != param_size) {
+    if (inputs[kExpAvgSQIndex]->size != param_size) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq' should be " << param_size
-                        << ", but got " << inputs[EXP_AVG_SQ]->size;
+                        << ", but got " << inputs[kExpAvgSQIndex]->size;
     }
     return;
   }

-  if (inputs[EXP_AVG_SQ_ROW]->size != param_size / last_row_dim_size_) {
+  if (inputs[kExpAvgSQRowIndex]->size != param_size / last_row_dim_size_) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_row' should be "
-                      << param_size / last_row_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_ROW]->size;
+                      << param_size / last_row_dim_size_ << ", but got " << inputs[kExpAvgSQRowIndex]->size;
   }
-  if (inputs[EXP_AVG_SQ_COL]->size != param_size / last_col_dim_size_) {
+  if (inputs[kExpAvgSQColIndex]->size != param_size / last_col_dim_size_) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'exp_avg_sq_col' should be "
-                      << param_size / last_col_dim_size_ << ", but got " << inputs[EXP_AVG_SQ_COL]->size;
+                      << param_size / last_col_dim_size_ << ", but got " << inputs[kExpAvgSQColIndex]->size;
   }
 }

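Note: every full-sized input is validated against param_size = elem_num_ * kSizeFloat16 for float16 parameters and elem_num_ * kSizeFloat32 otherwise, while the factored accumulators must be smaller by a trailing dimension. For a hypothetical 1024 x 512 float32 parameter, param_size = 1024 * 512 * 4 bytes = 2 MiB, and exp_avg_sq_row must arrive as param_size / last_row_dim_size_ bytes (exp_avg_sq_col analogously with last_col_dim_size_).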
@@ -358,19 +374,18 @@ void FusedAdaFactorCpuKernelMod::CheckWorkspaceAddresses(const std::vector<kerne
   }

   size_t update_size = elem_num_ * kSizeFloat32;
-
-  if (workspaces[UPDATE]->size != elem_num_ * kSizeFloat32) {
+  if (workspaces[kWorkSpaceUpdateIndex]->size != elem_num_ * kSizeFloat32) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'update ' should be " << update_size
-                      << ", but got " << workspaces[0]->size;
+                      << ", but got " << workspaces[kWorkSpaceUpdateIndex]->size;
   }

-  if (workspaces[R_FACTOR]->size != update_size / last_row_dim_size_) {
+  if (workspaces[kWorkSpaceRFactorIndex]->size != update_size / last_row_dim_size_) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'r_factor' should be "
-                      << update_size / last_row_dim_size_ << ", but got " << workspaces[R_FACTOR]->size;
+                      << update_size / last_row_dim_size_ << ", but got " << workspaces[kWorkSpaceRFactorIndex]->size;
   }
-  if (workspaces[C_FACTOR]->size != update_size / last_col_dim_size_) {
+  if (workspaces[kWorkSpaceCFactorIndex]->size != update_size / last_col_dim_size_) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'c_factor' should be "
-                      << update_size / last_col_dim_size_ << ", but got " << workspaces[C_FACTOR]->size;
+                      << update_size / last_col_dim_size_ << ", but got " << workspaces[kWorkSpaceCFactorIndex]->size;
   }
 }