|
|
@@ -33,6 +33,8 @@ class TensorConstructUtils { |
|
|
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape); |
|
|
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape); |
|
|
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape); |
|
|
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape); |
|
|
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data); |
|
|
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data); |
|
|
|
|
|
|
|
|
|
|
|
private: |
|
|
static TypeId ExtractTypeId(const TypePtr type); |
|
|
static TypeId ExtractTypeId(const TypePtr type); |
|
|
}; |
|
|
}; |
|
|
} // namespace mindspore |
|
|
} // namespace mindspore |
|
|
|