Browse Source

fix static check

pull/15685/head
liuyu 5 years ago
parent
commit
4f6cf2ec3c
23 changed files with 69 additions and 116 deletions
  1. +3
    -4
      mindspore/core/ops/crop_and_resize.cc
  2. +3
    -7
      mindspore/core/ops/crop_and_resize.h
  3. +0
    -4
      mindspore/core/ops/erf.cc
  4. +0
    -3
      mindspore/core/ops/erf.h
  5. +6
    -6
      mindspore/core/ops/fusion/conv2d_transpose_fusion.cc
  6. +4
    -4
      mindspore/core/ops/fusion/conv2d_transpose_fusion.h
  7. +0
    -1
      mindspore/core/ops/grad/abs_grad.h
  8. +7
    -7
      mindspore/core/ops/grad/strided_slice_grad.cc
  9. +7
    -7
      mindspore/core/ops/grad/strided_slice_grad.h
  10. +13
    -13
      mindspore/core/ops/gru.cc
  11. +13
    -17
      mindspore/core/ops/gru.h
  12. +0
    -4
      mindspore/core/ops/invert_permutation.cc
  13. +0
    -3
      mindspore/core/ops/invert_permutation.h
  14. +0
    -1
      mindspore/core/ops/lin_space.h
  15. +0
    -4
      mindspore/core/ops/non_zero.cc
  16. +0
    -5
      mindspore/core/ops/non_zero.h
  17. +1
    -1
      mindspore/core/ops/op_utils.h
  18. +2
    -3
      mindspore/core/ops/random_standard_normal.cc
  19. +3
    -7
      mindspore/core/ops/random_standard_normal.h
  20. +0
    -4
      mindspore/core/ops/size.cc
  21. +0
    -3
      mindspore/core/ops/size.h
  22. +4
    -4
      mindspore/core/ops/uniform_real.cc
  23. +3
    -4
      mindspore/core/ops/uniform_real.h

+ 3
- 4
mindspore/core/ops/crop_and_resize.cc View File

@@ -15,7 +15,6 @@
*/

#include <set>
#include <vector>
#include <memory>
#include "ops/crop_and_resize.h"
#include "utils/check_convert_utils.h"
@@ -23,17 +22,17 @@

namespace mindspore {
namespace ops {
void CropAndResize::Init(const ResizeMethod method, const float extrapolation_value) {
void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {
this->set_method(method);
this->set_extrapolation_value(extrapolation_value);
}

void CropAndResize::set_method(const ResizeMethod method) {
void CropAndResize::set_method(ResizeMethod method) {
auto swi = (int64_t)method;
this->AddAttr(kMethod, MakeValue(swi));
}

void CropAndResize::set_extrapolation_value(const float extrapolation_value) {
void CropAndResize::set_extrapolation_value(float extrapolation_value) {
this->AddAttr(kExtrapolationValue, MakeValue(extrapolation_value));
}



+ 3
- 7
mindspore/core/ops/crop_and_resize.h View File

@@ -30,17 +30,13 @@ class CropAndResize : public PrimitiveC {
CropAndResize() : PrimitiveC(kNameCropAndResize) { InitIOName({"x", "boxes", "box_index", "crop_size"}, {"y"}); }
~CropAndResize() = default;
MS_DECLARE_PARENT(CropAndResize, PrimitiveC);
void Init(const ResizeMethod method, const float extrapolation_value);

void set_method(const ResizeMethod method);
void set_extrapolation_value(const float extrapolation_value);
void Init(ResizeMethod method, float extrapolation_value);

void set_method(ResizeMethod method);
void set_extrapolation_value(float extrapolation_value);
ResizeMethod get_method() const;
float get_extrapolation_value() const;
};

AbstractBasePtr CropAndResizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCropAndResizePtr = std::shared_ptr<CropAndResize>;
} // namespace ops
} // namespace mindspore


+ 0
- 4
mindspore/core/ops/erf.cc View File

@@ -14,12 +14,8 @@
* limitations under the License.
*/

#include <set>
#include <vector>
#include <memory>
#include "ops/erf.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {


+ 0
- 3
mindspore/core/ops/erf.h View File

@@ -30,11 +30,8 @@ class Erf : public PrimitiveC {
Erf() : PrimitiveC(kNameErf) { InitIOName({"x"}, {"y"}); }
~Erf() = default;
MS_DECLARE_PARENT(Erf, PrimitiveC);
void Init() {}
};

AbstractBasePtr ErfInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimErfPtr = std::shared_ptr<Erf>;
} // namespace ops
} // namespace mindspore


+ 6
- 6
mindspore/core/ops/fusion/conv2d_transpose_fusion.cc View File

@@ -23,7 +23,7 @@ void Conv2dTransposeFusion::Init(int64_t in_channel, int64_t out_channel, const
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
int64_t group, const Format &format, const std::vector<int64_t> &pad_list,
const std::vector<int64_t> &output_paddings, const ActivationType activation_type) {
const std::vector<int64_t> &output_paddings, ActivationType activation_type) {
set_in_channel(in_channel);
set_out_channel(out_channel);
set_kernel_size(kernel_size);
@@ -56,20 +56,20 @@ void Conv2dTransposeFusion::set_dilation(const std::vector<int64_t> &dilation) {
}

void Conv2dTransposeFusion::set_output_paddings(const std::vector<int64_t> &output_paddings) {
CheckAndConvertUtils::CheckInteger(koutputPaddings, output_paddings.size(), kGreaterEqual, 1, name());
CheckAndConvertUtils::CheckInteger(kOutputPaddings, output_paddings.size(), kGreaterEqual, 1, name());
for (int64_t item : output_paddings) {
CheckAndConvertUtils::CheckInteger(koutputPaddings, item, kGreaterEqual, 0, name());
CheckAndConvertUtils::CheckInteger(kOutputPaddings, item, kGreaterEqual, 0, name());
}
AddAttr(koutputPaddings, MakeValue(output_paddings));
AddAttr(kOutputPaddings, MakeValue(output_paddings));
}

void Conv2dTransposeFusion::set_activation_type(const ActivationType activation_type) {
void Conv2dTransposeFusion::set_activation_type(ActivationType activation_type) {
int64_t swi = activation_type;
this->AddAttr(kActivationType, MakeValue(swi));
}

std::vector<int64_t> Conv2dTransposeFusion::get_output_paddings() const {
auto value_ptr = GetAttr(koutputPaddings);
auto value_ptr = GetAttr(kOutputPaddings);
return GetValue<std::vector<int64_t>>(value_ptr);
}



+ 4
- 4
mindspore/core/ops/fusion/conv2d_transpose_fusion.h View File

@@ -36,11 +36,11 @@ class Conv2dTransposeFusion : public Conv2dTranspose {
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
int64_t group = 1, const Format &format = NCHW, const std::vector<int64_t> &pad_list = {0, 0, 0, 0},
const std::vector<int64_t> &output_paddings = {0}, const ActivationType activation_type = NO_ACTIVATION);
void set_kernel_size(const std::vector<int64_t> &kernel_size);
void set_dilation(const std::vector<int64_t> &dilation);
const std::vector<int64_t> &output_paddings = {0}, ActivationType activation_type = NO_ACTIVATION);
void set_kernel_size(const std::vector<int64_t> &kernel_size) override;
void set_dilation(const std::vector<int64_t> &dilation) override;
void set_output_paddings(const std::vector<int64_t> &output_paddings);
void set_activation_type(const ActivationType activation_type);
void set_activation_type(ActivationType activation_type);

std::vector<int64_t> get_output_paddings() const;
ActivationType get_activation_type() const;


+ 0
- 1
mindspore/core/ops/grad/abs_grad.h View File

@@ -32,7 +32,6 @@ class AbsGrad : public PrimitiveC {
AbsGrad() : PrimitiveC(kNameAbsGrad) {}
~AbsGrad() = default;
MS_DECLARE_PARENT(AbsGrad, PrimitiveC);
void Init() {}
};
} // namespace ops
} // namespace mindspore


+ 7
- 7
mindspore/core/ops/grad/strided_slice_grad.cc View File

@@ -23,8 +23,8 @@

namespace mindspore {
namespace ops {
void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
void StridedSliceGrad::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);
this->set_end_mask(end_mask);
this->set_ellipsis_mask(ellipsis_mask);
@@ -32,7 +32,7 @@ void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, co
this->set_shrink_axis_mask(shrink_axis_mask);
}

void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) {
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kBeginMask, MakeValue(begin_mask));
}
@@ -40,7 +40,7 @@ int64_t StridedSliceGrad::get_begin_mask() const {
auto value_ptr = GetAttr(kBeginMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_end_mask(const int64_t end_mask) {
void StridedSliceGrad::set_end_mask(int64_t end_mask) {
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kEndMask, MakeValue(end_mask));
}
@@ -48,7 +48,7 @@ int64_t StridedSliceGrad::get_end_mask() const {
auto value_ptr = GetAttr(kEndMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) {
void StridedSliceGrad::set_ellipsis_mask(int64_t ellipsis_mask) {
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
std::ostringstream buffer;
@@ -62,7 +62,7 @@ int64_t StridedSliceGrad::get_ellipsis_mask() const {
auto value_ptr = GetAttr(kEllipsisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) {
void StridedSliceGrad::set_new_axis_mask(int64_t new_axis_mask) {
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
}
@@ -70,7 +70,7 @@ int64_t StridedSliceGrad::get_new_axis_mask() const {
auto value_ptr = GetAttr(kNewAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
void StridedSliceGrad::set_shrink_axis_mask(int64_t shrink_axis_mask) {
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
}


+ 7
- 7
mindspore/core/ops/grad/strided_slice_grad.h View File

@@ -33,13 +33,13 @@ class StridedSliceGrad : public PrimitiveC {
StridedSliceGrad() : PrimitiveC(kNameStridedSliceGrad) {}
~StridedSliceGrad() = default;
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
void set_begin_mask(const int64_t begin_mask);
void set_end_mask(const int64_t end_mask);
void set_ellipsis_mask(const int64_t ellipsis_mask);
void set_new_axis_mask(const int64_t new_axis_mask);
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
int64_t shrink_axis_mask = 0);
void set_begin_mask(int64_t begin_mask);
void set_end_mask(int64_t end_mask);
void set_ellipsis_mask(int64_t ellipsis_mask);
void set_new_axis_mask(int64_t new_axis_mask);
void set_shrink_axis_mask(int64_t shrink_axis_mask);
int64_t get_begin_mask() const;
int64_t get_end_mask() const;
int64_t get_ellipsis_mask() const;


+ 13
- 13
mindspore/core/ops/gru.cc View File

@@ -18,9 +18,9 @@

namespace mindspore {
namespace ops {
void GRU::Init(const bool bidirectional, const int64_t cell_depth, const float keep_prob, const float cell_clip,
const int64_t num_proj, const bool time_major, const bool reset_after, const bool is_training,
const ActivationType activation, const GateOrderMode gate_order) {
void GRU::Init(bool bidirectional, int64_t cell_depth, float keep_prob, float cell_clip, int64_t num_proj,
bool time_major, bool reset_after, bool is_training, ActivationType activation,
GateOrderMode gate_order) {
this->set_bidirectional(bidirectional);
this->set_cell_depth(cell_depth);
this->set_keep_prob(keep_prob);
@@ -33,31 +33,31 @@ void GRU::Init(const bool bidirectional, const int64_t cell_depth, const float k
this->set_gate_order(gate_order);
}

void GRU::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
void GRU::set_bidirectional(bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }

void GRU::set_cell_depth(const int64_t cell_depth) { AddAttr(kCellDepth, MakeValue(cell_depth)); }
void GRU::set_cell_depth(int64_t cell_depth) { AddAttr(kCellDepth, MakeValue(cell_depth)); }

void GRU::set_keep_prob(const float keep_prob) { AddAttr(kKeepProb, MakeValue(keep_prob)); }
void GRU::set_keep_prob(float keep_prob) { AddAttr(kKeepProb, MakeValue(keep_prob)); }

void GRU::set_cell_clip(const float cell_clip) { AddAttr(kCellClip, MakeValue(cell_clip)); }
void GRU::set_cell_clip(float cell_clip) { AddAttr(kCellClip, MakeValue(cell_clip)); }

void GRU::set_num_proj(const int64_t num_proj) {
void GRU::set_num_proj(int64_t num_proj) {
CheckAndConvertUtils::CheckInteger(kNumProj, num_proj, kGreaterThan, 0, this->name());
AddAttr(kNumProj, MakeValue(num_proj));
}

void GRU::set_time_major(const bool time_major) { AddAttr(kTimeMajor, MakeValue(time_major)); }
void GRU::set_time_major(bool time_major) { AddAttr(kTimeMajor, MakeValue(time_major)); }

void GRU::set_reset_after(const bool reset_after) { AddAttr(kResetAfter, MakeValue(reset_after)); }
void GRU::set_reset_after(bool reset_after) { AddAttr(kResetAfter, MakeValue(reset_after)); }

void GRU::set_is_training(const bool is_training) { AddAttr(kIsTraining, MakeValue(is_training)); }
void GRU::set_is_training(bool is_training) { AddAttr(kIsTraining, MakeValue(is_training)); }

void GRU::set_activation(const ActivationType activation) {
void GRU::set_activation(ActivationType activation) {
int64_t swi = activation;
AddAttr(kActivation, MakeValue(swi));
}

void GRU::set_gate_order(const GateOrderMode gate_order) {
void GRU::set_gate_order(GateOrderMode gate_order) {
int64_t swi = gate_order;
AddAttr(kGateOrder, MakeValue(swi));
}


+ 13
- 17
mindspore/core/ops/gru.h View File

@@ -39,22 +39,20 @@ class GRU : public PrimitiveC {
}
~GRU() = default;
MS_DECLARE_PARENT(GRU, PrimitiveC);
void Init(const bool bidirectional = false, const int64_t cell_depth = 1, const float keep_prob = 1.0,
const float cell_clip = -1.0, const int64_t num_proj = 0, const bool time_major = true,
const bool reset_after = true, const bool is_training = true,
const ActivationType activation = ActivationType::TANH,
const GateOrderMode gate_order = GateOrderMode::RZH);
void Init(bool bidirectional = false, int64_t cell_depth = 1, float keep_prob = 1.0, float cell_clip = -1.0,
int64_t num_proj = 0, bool time_major = true, bool reset_after = true, bool is_training = true,
ActivationType activation = ActivationType::TANH, GateOrderMode gate_order = GateOrderMode::RZH);

void set_bidirectional(const bool bidirectional);
void set_cell_depth(const int64_t cell_depth);
void set_keep_prob(const float keep_prob);
void set_cell_clip(const float cell_clip);
void set_num_proj(const int64_t num_proj);
void set_time_major(const bool time_major);
void set_reset_after(const bool reset_after);
void set_is_training(const bool is_training);
void set_activation(const ActivationType activation);
void set_gate_order(const GateOrderMode gate_order);
void set_bidirectional(bool bidirectional);
void set_cell_depth(int64_t cell_depth);
void set_keep_prob(float keep_prob);
void set_cell_clip(float cell_clip);
void set_num_proj(int64_t num_proj);
void set_time_major(bool time_major);
void set_reset_after(bool reset_after);
void set_is_training(bool is_training);
void set_activation(ActivationType activation);
void set_gate_order(GateOrderMode gate_order);

bool get_bidirectional() const;
int64_t get_cell_depth() const;
@@ -68,8 +66,6 @@ class GRU : public PrimitiveC {
GateOrderMode get_gate_order() const;
};

AbstractBasePtr GRUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimGRUPtr = std::shared_ptr<GRU>;
} // namespace ops
} // namespace mindspore


+ 0
- 4
mindspore/core/ops/invert_permutation.cc View File

@@ -14,12 +14,8 @@
* limitations under the License.
*/

#include <set>
#include <vector>
#include <memory>
#include "ops/invert_permutation.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {


+ 0
- 3
mindspore/core/ops/invert_permutation.h View File

@@ -30,11 +30,8 @@ class InvertPermutation : public PrimitiveC {
InvertPermutation() : PrimitiveC(kNameInvertPermutation) {}
~InvertPermutation() = default;
MS_DECLARE_PARENT(InvertPermutation, PrimitiveC);
void Init() {}
};

AbstractBasePtr InvertPermutationInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimInvertPermutationPtr = std::shared_ptr<InvertPermutation>;
} // namespace ops
} // namespace mindspore


+ 0
- 1
mindspore/core/ops/lin_space.h View File

@@ -31,7 +31,6 @@ class LinSpace : public PrimitiveC {
LinSpace() : PrimitiveC(kNameLinSpace) { InitIOName({"start", "stop", "num"}, {"output"}); }
~LinSpace() = default;
MS_DECLARE_PARENT(LinSpace, PrimitiveC);
void Init() {}
};
} // namespace ops
} // namespace mindspore


+ 0
- 4
mindspore/core/ops/non_zero.cc View File

@@ -14,12 +14,8 @@
* limitations under the License.
*/

#include <set>
#include <vector>
#include <memory>
#include "ops/non_zero.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {


+ 0
- 5
mindspore/core/ops/non_zero.h View File

@@ -16,7 +16,6 @@

#ifndef MINDSPORE_CORE_OPS_NON_ZERO_H_
#define MINDSPORE_CORE_OPS_NON_ZERO_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
@@ -30,11 +29,7 @@ class NonZero : public PrimitiveC {
NonZero() : PrimitiveC(kNameNonZero) {}
~NonZero() = default;
MS_DECLARE_PARENT(NonZero, PrimitiveC);
void Init() {}
};

AbstractBasePtr NonZeroInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimNonZeroPtr = std::shared_ptr<NonZero>;
} // namespace ops
} // namespace mindspore


+ 1
- 1
mindspore/core/ops/op_utils.h View File

@@ -135,7 +135,7 @@ constexpr auto kOutChannel = "out_channel";
constexpr auto kOutMaxValue = "out_max_value";
constexpr auto kOutputChannel = "output_channel";
constexpr auto kOutputNum = "output_num";
constexpr auto koutputPaddings = "output_paddings";
constexpr auto kOutputPaddings = "output_paddings";
constexpr auto kOutputType = "output_type";
constexpr auto kOutQuantized = "out_quantized";
constexpr auto kP = "p";


+ 2
- 3
mindspore/core/ops/random_standard_normal.cc View File

@@ -16,7 +16,6 @@
#include "ops/random_standard_normal.h"
#include <string>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

@@ -27,9 +26,9 @@ void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
this->set_seed2(seed2);
}

void RandomStandardNormal::set_seed(const int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
void RandomStandardNormal::set_seed(int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }

void RandomStandardNormal::set_seed2(const int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
void RandomStandardNormal::set_seed2(int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }

int64_t RandomStandardNormal::get_seed() const {
auto value_ptr = GetAttr(kSeed);


+ 3
- 7
mindspore/core/ops/random_standard_normal.h View File

@@ -32,17 +32,13 @@ class RandomStandardNormal : public PrimitiveC {
RandomStandardNormal() : PrimitiveC(kNameRandomStandardNormal) {}
~RandomStandardNormal() = default;
MS_DECLARE_PARENT(RandomStandardNormal, PrimitiveC);
void Init(const int64_t seed, const int64_t seed2);

void set_seed(const int64_t seed);
void set_seed2(const int64_t seed2);
void Init(int64_t seed, int64_t seed2);

void set_seed(int64_t seed);
void set_seed2(int64_t seed2);
int64_t get_seed() const;
int64_t get_seed2() const;
};

AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimRandomStandardNormalPtr = std::shared_ptr<RandomStandardNormal>;
} // namespace ops
} // namespace mindspore


+ 0
- 4
mindspore/core/ops/size.cc View File

@@ -14,12 +14,8 @@
* limitations under the License.
*/

#include <set>
#include <vector>
#include <memory>
#include "ops/size.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {


+ 0
- 3
mindspore/core/ops/size.h View File

@@ -30,11 +30,8 @@ class Size : public PrimitiveC {
Size() : PrimitiveC(kNameSize) {}
~Size() = default;
MS_DECLARE_PARENT(Size, PrimitiveC);
void Init() {}
};

AbstractBasePtr SizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSizePtr = std::shared_ptr<Size>;
} // namespace ops
} // namespace mindspore


+ 4
- 4
mindspore/core/ops/uniform_real.cc View File

@@ -13,23 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/uniform_real.h"
#include <string>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
void UniformReal::Init(const int64_t seed, const int64_t seed2) {
void UniformReal::Init(int64_t seed, int64_t seed2) {
this->set_seed(seed);
this->set_seed2(seed2);
}

void UniformReal::set_seed(const int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
void UniformReal::set_seed(int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }

void UniformReal::set_seed2(const int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
void UniformReal::set_seed2(int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }

int64_t UniformReal::get_seed() const {
auto value_ptr = GetAttr(kSeed);


+ 3
- 4
mindspore/core/ops/uniform_real.h View File

@@ -32,11 +32,10 @@ class UniformReal : public PrimitiveC {
UniformReal() : PrimitiveC(kNameUniformReal) {}
~UniformReal() = default;
MS_DECLARE_PARENT(UniformReal, PrimitiveC);
void Init(const int64_t seed, const int64_t seed2);

void set_seed(const int64_t seed);
void set_seed2(const int64_t seed2);
void Init(int64_t seed, int64_t seed2);

void set_seed(int64_t seed);
void set_seed2(int64_t seed2);
int64_t get_seed() const;
int64_t get_seed2() const;
};


Loading…
Cancel
Save