Browse Source

fix pool grad

tags/v1.2.0-rc1
xuanyue 5 years ago
parent
commit
46142c4519
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/core/ops/grad/pool_grad.cc

+ 3
- 3
mindspore/core/ops/grad/pool_grad.cc View File

@@ -55,8 +55,8 @@ void PoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<i
}

void PoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) {
std::vector<int64_t> k_size = _grad_check_vector(kSize, kernel_size, this->name());
this->AddAttr(kSize, MakeValue(k_size));
std::vector<int64_t> k_size = _grad_check_vector(kKernelSize, kernel_size, this->name());
this->AddAttr(kKernelSize, MakeValue(k_size));
}

void PoolGrad::set_strides(const std::vector<int64_t> &strides) {
@@ -75,7 +75,7 @@ void PoolGrad::set_format(const Format &format) {
}

std::vector<int64_t> PoolGrad::get_kernel_size() const {
auto value_ptr = GetAttr(kSize);
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}



Loading…
Cancel
Save