@@ -34,11 +34,11 @@ void SparseApplyLazyAdamPSKernelMod::InitKernel(
     MS_LOG(EXCEPTION) << "SparseApplyLazyAdamPSKernelMod needs " << kSparseApplyLazyAdamPSInputsSize
                       << " input shapes, but got " << shape_vec.size();
   }
-  std::vector<size_t> &var_shape = *(shape_vec[0]);
-  std::vector<size_t> &m_shape = *(shape_vec[1]);
-  std::vector<size_t> &v_shape = *(shape_vec[2]);
-  const std::vector<size_t> &grad_shape = *(shape_vec[9]);
-  const std::vector<size_t> &indices_shape = *(shape_vec[10]);
+  std::vector<size_t> &var_shape = *(shape_vec[var_index_]);
+  std::vector<size_t> &m_shape = *(shape_vec[m_index_]);
+  std::vector<size_t> &v_shape = *(shape_vec[v_index_]);
+  const std::vector<size_t> &grad_shape = *(shape_vec[grad_index_]);
+  const std::vector<size_t> &indices_shape = *(shape_vec[indices_index_]);
 
   Shard(&var_shape, 0);
   Shard(&m_shape, 0);
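
The hunk above replaces the hard-coded shape_vec positions (0, 1, 2, 9, 10) with named member indices. For context only, here is a minimal sketch of how such members could be kept on the kernel class; LazyAdamInputIndexSketch and its defaults are assumptions for illustration, not declarations taken from the MindSpore headers.

#include <cstddef>

// Hypothetical sketch (not from the diff): named input positions that match
// the literals removed above.
struct LazyAdamInputIndexSketch {
  std::size_t var_index_{0};       // position of var in shape_vec / inputs
  std::size_t m_index_{1};         // first moment (m)
  std::size_t v_index_{2};         // second moment (v)
  std::size_t grad_index_{9};      // gradient
  std::size_t indices_index_{10};  // sparse indices
};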
@@ -94,7 +94,7 @@ void SparseApplyLazyAdamPSKernelMod::ReInit(const std::vector<AddressPtr> &input
     MS_LOG(EXCEPTION) << "Input shape size should not be less than " << kSparseApplyLazyAdamPSInputsSize << ", but got "
                       << inputs.size();
   }
-  const auto &indices_addr = inputs[10];
+  const auto &indices_addr = inputs[indices_index_];
   indices_size_ = indices_addr->size / sizeof(int);
   workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_;
   workspace_size_list_[1] = indices_size_ * sizeof(int) * worker_num_;
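
The ReInit hunk only changes how the indices buffer is looked up, but the surrounding context lines show how the two workspace sizes are derived from it. A minimal standalone sketch of that arithmetic follows; the numeric values (512-byte indices buffer, 16 elements per sliced row, 4 workers) are assumptions for illustration and do not come from the diff.

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed values for illustration only; real sizes come from the kernel state.
  std::size_t indices_bytes = 512;                         // stands in for indices_addr->size
  std::size_t indices_size = indices_bytes / sizeof(int);  // 128 indices
  std::size_t var_outer_dim_size = 16;                     // elements per sliced row
  std::size_t worker_num = 4;                              // parameter-server workers

  // Mirrors the two workspace_size_list_ entries in the hunk above.
  std::size_t ws0 = indices_size * var_outer_dim_size * sizeof(float) * worker_num;  // 32768 bytes
  std::size_t ws1 = indices_size * sizeof(int) * worker_num;                         // 2048 bytes
  std::printf("workspace[0]=%zu bytes, workspace[1]=%zu bytes\n", ws0, ws1);
  return 0;
}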