
first commit - updated op + STs

lint fix
tags/v1.1.0
danishnxt · 5 years ago · commit 55c455df9e
3 changed files with 37 additions and 4 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cu (+11, -0)
  2. mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.cc (+22, -0)
  3. tests/st/ops/gpu/test_sparse_apply_ftrl_op.py (+4, -4)

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_ftrl_impl.cu (+11, -0)

@@ -96,8 +96,19 @@ template void CalSparseApplyFtrl<float, int>(const float *gradient, const int *i
                                              const float l1_regularization, const float l2_regularization,
                                              const float learning_rate_power, const bool use_locking, float *variable,
                                              float *accumulation, float *linear, cudaStream_t cuda_stream);
+template void CalSparseApplyFtrl<float, int64_t>(const float *gradient, const int64_t *indices, const int num_index,
+                                                 const size_t n_stride, const float learning_rate,
+                                                 const float l1_regularization, const float l2_regularization,
+                                                 const float learning_rate_power, const bool use_locking, float *variable,
+                                                 float *accumulation, float *linear, cudaStream_t cuda_stream);
 template void CalSparseApplyFtrl<half, int>(const half *gradient, const int *indices, const int num_index,
                                             const size_t n_stride, const float learning_rate,
                                             const float l1_regularization, const float l2_regularization,
                                             const float learning_rate_power, const bool use_locking, half *variable,
                                             half *accumulation, half *linear, cudaStream_t cuda_stream);
+template void CalSparseApplyFtrl<half, int64_t>(const half *gradient, const int64_t *indices, const int num_index,
+                                                const size_t n_stride, const float learning_rate,
+                                                const float l1_regularization, const float l2_regularization,
+                                                const float learning_rate_power, const bool use_locking, half *variable,
+                                                half *accumulation, half *linear, cudaStream_t cuda_stream);
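The two added explicit instantiations (float and half values with int64_t indices) are what this hunk contributes; without them the templated kernel body would never be compiled for 64-bit index tensors and the int64 registrations in the next file could not link. For intuition, here is a minimal NumPy sketch of the per-row FTRL-proximal update such a kernel applies; the helper name, loop structure, and default hyperparameters are illustrative, not taken from this file.

import numpy as np

def sparse_apply_ftrl_ref(var, accum, linear, grad, indices,
                          lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5):
    # Hypothetical NumPy reference of the standard sparse FTRL-proximal rule.
    # Only the rows listed in `indices` are touched, matching the "sparse" op.
    for k, row in enumerate(indices):
        g = grad[k]
        accum_new = accum[row] + g * g
        # sigma captures the change in the accumulator term
        sigma = (accum_new ** -lr_power - accum[row] ** -lr_power) / lr
        linear[row] += g - sigma * var[row]
        quadratic = accum_new ** -lr_power / lr + 2.0 * l2
        # proximal step: shrink toward zero, clipping at the l1 threshold
        var[row] = np.where(np.abs(linear[row]) > l1,
                            (np.sign(linear[row]) * l1 - linear[row]) / quadratic,
                            0.0)
        accum[row] = accum_new
    return var, accum, linear

With lr_power = -0.5 the accumulator powers reduce to square roots, the usual FTRL configuration; the index dtype only affects how rows are addressed, never the arithmetic, which is why the tests below can keep their expected values unchanged.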


mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_ftrl_gpu_kernel.cc (+22, -0)

@@ -29,6 +29,17 @@ MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       SparseFtrlGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SparseFtrlGpuKernel, float, int64_t)
 MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
@@ -40,5 +51,16 @@ MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
                         .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
                       SparseFtrlGpuKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(SparseApplyFtrl,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      SparseFtrlGpuKernel, half, int64_t)
 } // namespace kernel
 } // namespace mindspore
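Each MS_REG_GPU_KERNEL_TWO entry binds one (value dtype, index dtype) combination to the SparseFtrlGpuKernel template, so these two new registrations are what lets the GPU backend pick this kernel when the indices input arrives as Int64 rather than Int32. A minimal front-end sketch of how that path would be exercised; the Net wrapper, shapes, and hyperparameter values here are illustrative, mirroring the pattern of the test file below rather than quoting it.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Placeholder hyperparameters, not taken from the diff.
        self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0,
                                                   lr_power=-0.5, use_locking=False)
        # var/accum/linear are optimizer state held as Parameters.
        self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")

    def construct(self, grad, indices):
        return self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)

net = Net()
grad = Tensor(np.ones([2, 3, 3]).astype(np.float32))
indices = Tensor([0, 2], mstype.int64)  # int64 indices now dispatch to the new kernel
net(grad, indices)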

tests/st/ops/gpu/test_sparse_apply_ftrl_op.py (+4, -4)

@@ -77,9 +77,9 @@ def test_ftrl():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_ftrl_sparse():
+def test_ftrl_sparse_int64_ind():
     gradient = Tensor(np.ones([2, 3, 3]).astype(np.float32))
-    indices = Tensor([0, 2], mstype.int32)
+    indices = Tensor([0, 2], mstype.int64)
     expect_var = np.array([[[0.291479, 0.291479, 0.291479],
                             [0.291479, 0.291479, 0.291479],
                             [0.291479, 0.291479, 0.291479]],
@@ -127,9 +127,9 @@ def test_ftrl_half():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_ftrl_sparse_half():
+def test_ftrl_sparse_half_int64_ind():
     gradient = Tensor(np.ones([2, 3, 3]).astype(np.float16))
-    indices = Tensor([0, 2], mstype.int32)
+    indices = Tensor([0, 2], mstype.int64)
     expect_var = np.array([[[0.291479, 0.291479, 0.291479],
                             [0.291479, 0.291479, 0.291479],
                             [0.291479, 0.291479, 0.291479]],
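In both tests only the function name and the index dtype change; the expected values stay as-is because the int64 path must produce results identical to the int32 one. The hunks are truncated before each test's assertions, which presumably follow the file's existing pattern; a hypothetical closing along those lines:

    # Hypothetical continuation of the test body, following the apparent
    # pattern of the existing sparse tests in this file.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    sparse_apply_ftrl = Net()
    sparse_apply_ftrl(gradient, indices)
    assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)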

