Merge pull request !2245 from hewei/decouple_tensortags/v0.6.0-beta
| @@ -27,6 +27,7 @@ | |||
| #include "utils/symbolic.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "ir/tensor_py.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/parse/resolve.h" | |||
| #include "operator/composite/composite.h" | |||
| @@ -39,6 +40,8 @@ | |||
| #include "utils/context/ms_context.h" | |||
| #include "operator/ops.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| // max number of elements in sequence | |||
| const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; | |||
| @@ -399,7 +402,7 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const Valu | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||
| oss << value->DumpText() << "@" << DumpObject(tensor_ptr->data(), "T"); | |||
| oss << value->DumpText() << "@" << DumpObject(TensorPy::AsNumpy(*tensor_ptr), "T"); | |||
| } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<NullObj>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<ValueSequeue>()) { | |||
| @@ -1813,7 +1816,7 @@ class IrParser { | |||
| if (tensor_data == nullptr) { | |||
| return TOK_ERROR; | |||
| } | |||
| *val_ptr = std::make_shared<tensor::Tensor>(tensor_data, TypeIdToType(type)); | |||
| *val_ptr = TensorPy::MakeTensor(tensor_data, TypeIdToType(type)); | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| @@ -117,7 +117,7 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto | |||
| continue; | |||
| } | |||
| float *start_addr = reinterpret_cast<float *>(tensor_ptr->data_c(false)); | |||
| float *start_addr = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||
| unsigned int num_elements = (tensor_ptr->data().nbytes()) / sizeof(float); | |||
| std::unordered_map<unsigned int, watchpoint_t>::iterator it_w_table_check; | |||
| @@ -144,7 +144,7 @@ void DebugServices::check_watchpoints(std::vector<std::string> *name, std::vecto | |||
| name->push_back(name_no_slot); | |||
| slot->push_back(std::to_string(tensor_list[i]->GetSlot())); | |||
| data_ptr->push_back(reinterpret_cast<char *>(tensor_ptr->data_c(false))); | |||
| data_ptr->push_back(reinterpret_cast<char *>(tensor_ptr->data_c())); | |||
| data_size->push_back(tensor_ptr->data().nbytes()); | |||
| int condition_item = -1; | |||
| @@ -182,7 +182,7 @@ void DebugServices::read_nodes_tensors(std::vector<std::string> name, std::vecto | |||
| continue; | |||
| } | |||
| ret_name->push_back(std::get<0>(result)); | |||
| data_ptr->push_back(reinterpret_cast<char *>(std::get<1>(result)->GetTensor()->data_c(false))); | |||
| data_ptr->push_back(reinterpret_cast<char *>(std::get<1>(result)->GetTensor()->data_c())); | |||
| data_size->push_back(std::get<1>(result)->GetTensor()->data().nbytes()); | |||
| dtype->push_back(std::get<1>(result)->GetTensor()->Dtype()); | |||
| shape->push_back(std::get<1>(result)->GetTensor()->shape()); | |||
| @@ -329,12 +329,12 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file | |||
| MS_LOG(INFO) << "E2E Dump path is " << path; | |||
| mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape); | |||
| size_t host_size = out_tensor->data().nbytes(); | |||
| ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c(true)); | |||
| ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Copy device mem to host failed"; | |||
| return ret; | |||
| } | |||
| ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(false), host_size); | |||
| ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); | |||
| } else { | |||
| auto host_tmp = std::vector<uint8_t>(size_); | |||
| auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_, ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); | |||
| @@ -364,7 +364,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens | |||
| MS_LOG(INFO) << "E2E tensor name is " << tensor_name; | |||
| mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape); | |||
| size_t host_size = out_tensor->data().nbytes(); | |||
| ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c(true)); | |||
| ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Copy device mem to host failed"; | |||
| return ret; | |||
| @@ -379,7 +379,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens | |||
| } else { | |||
| mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape); | |||
| size_t host_size = out_tensor->data().nbytes(); | |||
| auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(true), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); | |||
| auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); | |||
| auto tensor_data = std::make_shared<mindspore::TensorData>(); | |||
| tensor_data->SetName(tensor_name); | |||
| @@ -81,11 +81,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph | |||
| DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { | |||
| address->ptr_ = tensor->data_c(false); | |||
| address->ptr_ = tensor->data_c(); | |||
| } else { | |||
| address->ptr_ = resource_manager_.MemMalloc(tensor_size); | |||
| if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "Value node sync host to device failed!"; | |||
| } | |||
| } | |||
| @@ -178,7 +178,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k | |||
| tensor->set_device_address(address); | |||
| need_sync_outputs->emplace_back(tensor); | |||
| } else { | |||
| address->ptr_ = tensor->data_c(true); | |||
| address->ptr_ = tensor->data_c(); | |||
| address->ref_count_ = INIT_NODE_REF; | |||
| (void)bound_addresses->insert(address); | |||
| } | |||
| @@ -221,11 +221,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, | |||
| size_t tensor_size = | |||
| std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies<size_t>()); | |||
| if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { | |||
| address->ptr_ = tensor->data_c(false); | |||
| address->ptr_ = tensor->data_c(); | |||
| } else { | |||
| address->ptr_ = resource_manager_.MemMalloc(tensor_size); | |||
| if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; | |||
| } | |||
| tensor->set_dirty(true); | |||
| @@ -390,7 +390,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph | |||
| tensor->set_device_address(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(INFO) << "SyncHostToDevice failed."; | |||
| return false; | |||
| } | |||
| @@ -407,14 +407,14 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | |||
| tensor::TensorPtr loop_count_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(loop_count_tensor); | |||
| int32_t *val = nullptr; | |||
| val = static_cast<int32_t *>(loop_count_tensor->data_c(true)); | |||
| val = static_cast<int32_t *>(loop_count_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = 0; | |||
| inputs->push_back(loop_count_tensor); | |||
| tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(iter_loop_tensor); | |||
| val = static_cast<int32_t *>(iter_loop_tensor->data_c(true)); | |||
| val = static_cast<int32_t *>(iter_loop_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num())); | |||
| MS_LOG(INFO) << "iter_loop_tensor = " << *val; | |||
| @@ -422,14 +422,14 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) { | |||
| tensor::TensorPtr zero_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(zero_tensor); | |||
| val = static_cast<int32_t *>(zero_tensor->data_c(true)); | |||
| val = static_cast<int32_t *>(zero_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = 0; | |||
| inputs->push_back(zero_tensor); | |||
| tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(one_tensor); | |||
| val = static_cast<int32_t *>(one_tensor->data_c(true)); | |||
| val = static_cast<int32_t *>(one_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = 1; | |||
| inputs->push_back(one_tensor); | |||
| @@ -543,7 +543,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const | |||
| } | |||
| AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); | |||
| if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" | |||
| << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " | |||
| << AnfAlgo::GetOutputInferDataType(value_node, output_idx); | |||
| @@ -115,7 +115,7 @@ class MetaTensor : public Value { | |||
| // order it represents. | |||
| // | |||
| // return A const vector<int> which represents the shape of the tensor. | |||
| std::vector<int> shape() const { return shape_; } | |||
| const std::vector<int> &shape() const { return shape_; } | |||
| // brief Sets the shape of a tensor. | |||
| // | |||
| @@ -16,319 +16,261 @@ | |||
| #include "ir/tensor.h" | |||
| #include <atomic> | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "device/device_address.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| static uint64_t count = 0; | |||
| void DataBuf2Contiguous(const py::array &src, py::array *const dest) { | |||
| if (dest == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | |||
| } | |||
| Py_buffer pybuf_src; | |||
| if (PyObject_GetBuffer(src.ptr(), &pybuf_src, PyBUF_ANY_CONTIGUOUS)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get buffer info from the src!"; | |||
| } | |||
| using Bool = unsigned char; | |||
| if (!PyBuffer_IsContiguous(&pybuf_src, 'C')) { | |||
| if (PyBuffer_ToContiguous(dest->request(true).ptr, &pybuf_src, pybuf_src.len, 'C')) { | |||
| MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; | |||
| } | |||
| } else { | |||
| *dest = src; | |||
| } | |||
| PyBuffer_Release(&pybuf_src); | |||
| static std::string MakeId() { | |||
| // Use atomic to make id generator thread safe. | |||
| static std::atomic<uint64_t> last_id{1}; | |||
| return std::to_string(last_id.fetch_add(1, std::memory_order_relaxed)); | |||
| } | |||
| Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (type_ptr != nullptr) { | |||
| data_type = type_ptr->type_id(); | |||
| } | |||
| data_type_ = data_type; | |||
| shape_.resize(shape.size()); | |||
| for (size_t i = 0; i < shape.size(); ++i) { | |||
| shape_[i] = py::int_(shape[i]); | |||
| } | |||
| init(data_type_, shape_, &data_); | |||
| static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) { | |||
| return data_type ? data_type->type_id() : defaultTypeId; | |||
| } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); } | |||
| Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); } | |||
| Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) | |||
| : MetaTensor(tensor), device_address_(tensor.device_address_) { | |||
| init(tensor.data_, data_type); | |||
| dirty_ = tensor.is_dirty(); | |||
| id_ = tensor.id(); | |||
| static size_t SizeOf(const std::vector<int> &shape) { | |||
| return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>()); | |||
| } | |||
| Tensor &Tensor::operator=(const Tensor &tensor) { | |||
| if (this != &tensor) { | |||
| MetaTensor::operator=(tensor); | |||
| dirty_ = tensor.is_dirty(); | |||
| device_address_ = tensor.device_address(); | |||
| data_ = tensor.data_; | |||
| id_ = tensor.id(); | |||
| template <typename T> | |||
| std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) { | |||
| const size_t count = SizeOf(shape); | |||
| switch (data_type) { | |||
| case kNumberTypeBool: { | |||
| auto buf = static_cast<Bool *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeUInt8: { | |||
| auto buf = static_cast<uint8_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeInt8: { | |||
| auto buf = static_cast<int8_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeInt16: { | |||
| auto buf = static_cast<int16_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeInt32: { | |||
| auto buf = static_cast<int32_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeInt64: { | |||
| auto buf = static_cast<int64_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeUInt16: { | |||
| auto buf = static_cast<uint16_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeUInt32: { | |||
| auto buf = static_cast<uint32_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeUInt64: { | |||
| auto buf = static_cast<uint64_t *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeFloat16: { | |||
| auto buf = static_cast<float16 *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeFloat32: { | |||
| const float *buf = static_cast<float *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| case kNumberTypeFloat64: { | |||
| auto buf = static_cast<double *>(data); | |||
| return std::vector<T>(buf, buf + count); | |||
| } | |||
| default: | |||
| break; | |||
| } | |||
| return *this; | |||
| } | |||
| Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| *this = tensor; | |||
| return *this; | |||
| MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; | |||
| } | |||
| bool Tensor::operator==(const Tensor &tensor) const { | |||
| return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | |||
| } | |||
| bool Tensor::ValueEqual(const Tensor &other) const { | |||
| auto equal = [&other, this]() -> bool { | |||
| auto np = py::module::import("numpy"); | |||
| auto equal = np.attr("equal")(data_, other.data_); | |||
| auto all_equal = np.attr("all")(equal); | |||
| return all_equal.cast<bool>(); | |||
| }; | |||
| return (MetaTensor::operator==(other) && (data_.is(other.data_) || equal())); | |||
| // Convert to bool is not allowed. | |||
| template <> | |||
| std::vector<Bool> CopyData<Bool>(const std::vector<int> &shape, void *data, TypeId data_type) { | |||
| MS_LOG(EXCEPTION) << "Cannot convert from " << TypeIdLabel(data_type) << " to " << TypeIdLabel(kNumberTypeBool) | |||
| << "."; | |||
| return {}; | |||
| } | |||
| py::tuple Tensor::GetPyTupleShape() const { | |||
| std::vector<int> shape = this->shape(); | |||
| py::tuple dims(shape.size()); | |||
| for (size_t i = 0; i < dims.size(); ++i) { | |||
| dims[i] = py::int_(shape[i]); | |||
| template <typename T> | |||
| std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_len) { | |||
| size_t size = SizeOf(shape); | |||
| if (size * sizeof(T) != data_len) { | |||
| MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T) | |||
| << " item size " << sizeof(T); | |||
| } | |||
| return dims; | |||
| auto buf = static_cast<T *>(data); | |||
| return {buf, buf + size}; | |||
| } | |||
| int Tensor::DataDim() const { return static_cast<int>(data_.ndim()); } | |||
| // Tensor data implementation. | |||
| template <typename T> | |||
| class TensorDataImpl : public TensorData { | |||
| public: | |||
| explicit TensorDataImpl(const std::vector<int> &shape) : shape_(shape), data_(SizeOf(shape)) {} | |||
| int Tensor::DataSize() const { return static_cast<int>(data_.size()); } | |||
| TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len) | |||
| : shape_(shape), data_(CopyData<T>(shape, data, data_len)) {} | |||
| py::array Tensor::data() const { return data_; } | |||
| TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type) | |||
| : shape_(shape), data_(CopyData<T>(shape, data, data_type)) {} | |||
| int Tensor::data_type_c() const { return static_cast<int>(data_type_); } | |||
| template <typename InputIt> | |||
| TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last) : shape_(shape), data_(first, last) {} | |||
| std::vector<int> Tensor::shape_c(void) const { return shape(); } | |||
| template <typename Scalar> | |||
| TensorDataImpl(const std::vector<int> &shape, Scalar scalar) : shape_(shape), data_({static_cast<T>(scalar)}) {} | |||
| void *Tensor::data_c(bool writable) { | |||
| // operand of bit operation should be unsigned int. | |||
| unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; | |||
| bool is_c_contiguous = (flags != 0) ? true : false; | |||
| if (!is_c_contiguous) { | |||
| py::array data_c; | |||
| init(data_type_, shape_, &data_c); | |||
| DataBuf2Contiguous(data_, &data_c); | |||
| data_ = data_c; | |||
| } | |||
| return data_.request(writable).ptr; | |||
| } | |||
| ssize_t size() const override { return data_.size(); } | |||
| TypeId Tensor::GetDataType(const py::buffer_info &buf) const { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (buf.format.compare("e") == 0) { | |||
| data_type = TypeId::kNumberTypeFloat16; | |||
| } else if (buf.format.compare("f") == 0) { | |||
| data_type = TypeId::kNumberTypeFloat32; | |||
| } else if (buf.format.compare("d") == 0) { | |||
| data_type = TypeId::kNumberTypeFloat64; | |||
| } else if (buf.format.compare("B") == 0) { | |||
| data_type = TypeId::kNumberTypeUInt8; | |||
| } else if (buf.format.compare("H") == 0) { | |||
| data_type = TypeId::kNumberTypeUInt16; | |||
| } else if (buf.format.compare("I") == 0) { | |||
| data_type = TypeId::kNumberTypeUInt32; | |||
| } else if (buf.format.compare("L") == 0 || buf.format.compare("Q") == 0) { | |||
| data_type = TypeId::kNumberTypeUInt64; | |||
| } else if (buf.format.compare("b") == 0) { | |||
| data_type = TypeId::kNumberTypeInt8; | |||
| } else if (buf.format.compare("h") == 0) { | |||
| data_type = TypeId::kNumberTypeInt16; | |||
| } else if (buf.format.compare("i") == 0) { | |||
| data_type = TypeId::kNumberTypeInt32; | |||
| } else if (buf.format.compare("l") == 0 || buf.format.compare("q") == 0) { | |||
| data_type = TypeId::kNumberTypeInt64; | |||
| } else if (buf.format.compare("?") == 0) { | |||
| data_type = TypeId::kNumberTypeBool; | |||
| } else { | |||
| MS_LOG(WARNING) << "Get unsupported DataType " << buf.format << "."; | |||
| } | |||
| return data_type; | |||
| } | |||
| ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); } | |||
| void Tensor::init(const py::array &input, const TypePtr &type_ptr) { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (type_ptr != nullptr) { | |||
| data_type = type_ptr->type_id(); | |||
| } | |||
| init(input, data_type); | |||
| } | |||
| ssize_t nbytes() const override { return size() * itemsize(); } | |||
| void Tensor::init(const py::array &input, const TypeId &data_type) { | |||
| py::buffer_info buf = input.request(); | |||
| ssize_t ndim() const override { return static_cast<ssize_t>(shape_.size()); } | |||
| data_type_ = GetDataType(buf); | |||
| if (TypeId::kTypeUnknown == data_type && TypeId::kTypeUnknown == data_type_) { | |||
| MS_LOG(EXCEPTION) << "Unsupported tensor type!"; | |||
| void *data() override { | |||
| static std::vector<T> empty_data(1); | |||
| if (data_.empty()) { | |||
| // Prevent null pointer for empty data. | |||
| return empty_data.data(); | |||
| } | |||
| return data_.data(); | |||
| } | |||
| std::vector<ssize_t> tm = buf.shape; | |||
| size_t len = tm.size(); | |||
| std::vector<int> dims(len); | |||
| for (size_t i = 0; i < len; ++i) { | |||
| dims[i] = static_cast<int>(tm[i]); | |||
| bool equals(const TensorData &other) const override { | |||
| auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other); | |||
| if (ptr) { | |||
| return (ptr == this) || ((shape_ == ptr->shape_) && (data_ == ptr->data_)); | |||
| } | |||
| return false; | |||
| } | |||
| (void)set_shape(dims); | |||
| if (TypeId::kTypeUnknown != data_type && TypeId::kTypeUnknown != data_type_ && data_type_ != data_type) { | |||
| // If user defined data type is not same as GetDataType from the data | |||
| bool success = convert_data(input, data_type_, &data_, data_type); | |||
| if (success) { | |||
| data_type_ = data_type; | |||
| } else { | |||
| data_type_ = TypeId::kTypeUnknown; | |||
| MS_LOG(EXCEPTION) << "Convert data from " << data_type_ << " to " << data_type << " failed!"; | |||
| std::string ToString() const override { | |||
| std::ostringstream ss; | |||
| ss << '['; | |||
| for (auto value : data_) { | |||
| ss << value << ','; | |||
| } | |||
| } else { | |||
| data_ = input; | |||
| ss << ']'; | |||
| return ss.str(); | |||
| } | |||
| dirty_ = true; | |||
| id_ = std::to_string((uintptr_t)(this)) + std::to_string(count++); | |||
| } | |||
| void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) { | |||
| data_type_ = data_type; | |||
| shape_ = shape; | |||
| private: | |||
| std::vector<int> shape_; | |||
| std::vector<T> data_; | |||
| }; | |||
| template <typename... Args> | |||
| TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, Args... args) { | |||
| switch (data_type) { | |||
| case kNumberTypeBool: | |||
| *data = py::array_t<bool, py::array::c_style>(shape); | |||
| break; | |||
| // std::vector<bool> is a specialization of std::vector, | |||
| // it may use single bit instead of sizeof(bool) bytes, | |||
| // so we use std::vector<Bool> for bool tensors. | |||
| return std::make_shared<TensorDataImpl<Bool>>(shape, args...); | |||
| case kNumberTypeUInt8: | |||
| return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...); | |||
| case kNumberTypeInt8: | |||
| *data = py::array_t<int8_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<int8_t>>(shape, args...); | |||
| case kNumberTypeInt16: | |||
| *data = py::array_t<int16_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<int16_t>>(shape, args...); | |||
| case kNumberTypeInt32: | |||
| *data = py::array_t<int32_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<int32_t>>(shape, args...); | |||
| case kNumberTypeInt64: | |||
| *data = py::array_t<int64_t, py::array::c_style>(shape); | |||
| break; | |||
| case kNumberTypeUInt8: | |||
| *data = py::array_t<uint8_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<int64_t>>(shape, args...); | |||
| case kNumberTypeUInt16: | |||
| *data = py::array_t<uint16_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...); | |||
| case kNumberTypeUInt32: | |||
| *data = py::array_t<uint32_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...); | |||
| case kNumberTypeUInt64: | |||
| *data = py::array_t<uint64_t, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...); | |||
| case kNumberTypeFloat16: | |||
| *data = py::array_t<float16, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<float16>>(shape, args...); | |||
| case kNumberTypeFloat32: | |||
| *data = py::array_t<float, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<float>>(shape, args...); | |||
| case kNumberTypeFloat64: | |||
| *data = py::array_t<double, py::array::c_style>(shape); | |||
| break; | |||
| return std::make_shared<TensorDataImpl<double>>(shape, args...); | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; | |||
| break; | |||
| } | |||
| id_ = std::to_string((uintptr_t)(this)) + std::to_string(count++); | |||
| MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; | |||
| } | |||
| TypePtr Tensor::SetDtype(const TypePtr type_ptr) { | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| (void)set_data_type(type_ptr->type_id()); | |||
| return type_ptr; | |||
| } | |||
| Tensor::Tensor(const Tensor &tensor) | |||
| : MetaTensor(tensor), | |||
| init_flag_(tensor.init_flag_), | |||
| data_(tensor.data_), | |||
| dirty_(tensor.dirty_), | |||
| id_(tensor.id_), | |||
| device_address_(tensor.device_address_) {} | |||
| TypeId Tensor::set_data_type(const TypeId data_type) { | |||
| if (data_.size() > 0 && data_type_ != data_type) { | |||
| bool success = convert_data(data_, data_type_, &data_, data_type); | |||
| if (success) { | |||
| data_type_ = data_type; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Convert data from " << data_type_ << " to " << data_type << " failed!"; | |||
| } | |||
| } else if (data_.size() == 0) { | |||
| data_type_ = data_type; | |||
| } | |||
| Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| : MetaTensor(data_type, tensor.shape_), | |||
| init_flag_(tensor.init_flag_), | |||
| data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), | |||
| dirty_(tensor.dirty_), | |||
| id_(tensor.id_), | |||
| device_address_(tensor.device_address_) {} | |||
| return data_type_; | |||
| } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data) | |||
| : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} | |||
| bool Tensor::is_init() { return init_flag_; } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) | |||
| : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {} | |||
| void Tensor::set_init_flag(bool flag) { init_flag_ = flag; } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len) | |||
| : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {} | |||
| bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, | |||
| const TypeId out_data_type) { | |||
| if (out == nullptr) { | |||
| return false; | |||
| } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type) | |||
| : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {} | |||
| bool result = true; | |||
| if (TypeId::kTypeUnknown == in_data_type || TypeId::kTypeUnknown == out_data_type) { | |||
| result = false; | |||
| } else if (in_data_type == out_data_type) { | |||
| *out = in; | |||
| } else if (TypeId::kNumberTypeFloat64 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("float64").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeFloat32 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("float32").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeFloat16 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("float16").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeInt64 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("int64").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeInt32 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("int32").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeInt16 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("int16").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeInt8 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("int8").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeUInt8 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("uint8").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeUInt16 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("uint16").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeUInt32 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("uint32").cast<py::array>(); | |||
| } else if (TypeId::kNumberTypeUInt64 == out_data_type) { | |||
| *out = in.attr("astype").cast<py::function>()("uint64").cast<py::array>(); | |||
| } else { | |||
| data_type_ = TypeId::kTypeUnknown; | |||
| MS_LOG(EXCEPTION) << "Cannot convert from " << TypeIdLabel(in_data_type) << " to " << TypeIdLabel(out_data_type) | |||
| << "."; | |||
| } | |||
| Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type) | |||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}), | |||
| data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), | |||
| id_(MakeId()) {} | |||
| Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type) | |||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}), | |||
| data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), | |||
| id_(MakeId()) {} | |||
| Tensor::Tensor(int64_t input, const TypePtr &data_type) | |||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}), | |||
| data_(MakeTensorData(data_type_, {}, input)), | |||
| id_(MakeId()) {} | |||
| return result; | |||
| Tensor::Tensor(double input, const TypePtr &data_type) | |||
| : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}), | |||
| data_(MakeTensorData(data_type_, {}, input)), | |||
| id_(MakeId()) {} | |||
| bool Tensor::operator==(const Tensor &tensor) const { | |||
| return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); | |||
| } | |||
| bool Tensor::ValueEqual(const Tensor &tensor) const { | |||
| return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); | |||
| } | |||
| abstract::AbstractBasePtr Tensor::ToAbstract() { | |||
| @@ -355,7 +297,7 @@ std::string Tensor::ToString() const { | |||
| buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); | |||
| // only print small tensor | |||
| if (DataSize() < small_tensor_size) { | |||
| buf << "val:" << std::string(py::str(data())); | |||
| buf << "val:" << data().ToString(); | |||
| } | |||
| return buf.str(); | |||
| } | |||
| @@ -365,164 +307,25 @@ std::string Tensor::ToStringRepr() const { | |||
| auto type_ptr = this->Dtype(); | |||
| MS_EXCEPTION_IF_NULL(type_ptr); | |||
| buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString(); | |||
| buf << "\nval:" << std::string(py::str(data())); | |||
| buf << "\nval:" << data().ToString(); | |||
| return buf.str(); | |||
| } | |||
| py::array Tensor::data_sync() { | |||
| void Tensor::data_sync() const { | |||
| if (device_address_ != nullptr) { | |||
| if (!device_address_->SyncDeviceToHost(this->shape(), static_cast<size_t>(this->data().nbytes()), this->data_type(), | |||
| this->data_c(true))) { | |||
| if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy."; | |||
| } | |||
| } | |||
| return data_; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| // dtype should define before Tensor, because Tensor init depend dtype | |||
| (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") | |||
| .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) | |||
| .def(py::init<py::array, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init<py::float_, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init<py::int_, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init<py::list, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init<py::tuple, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init<Tensor, TypePtr>(), py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) | |||
| .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( | |||
| Get the tensor's data type. | |||
| Returns: | |||
| type, the data type of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) | |||
| >>> data.dtype | |||
| Int32 | |||
| )mydelimiter") | |||
| .def_property_readonly("shape", &Tensor::GetPyTupleShape, R"mydelimiter( | |||
| Get the tensor's shape. | |||
| Returns: | |||
| tuple[int], the shape of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((3, 3))) | |||
| >>> data.shape() | |||
| (3, 3) | |||
| )mydelimiter") | |||
| .def("asnumpy", &Tensor::data_sync, R"mydelimiter( | |||
| Convert tensor to numpy.ndarray. | |||
| Returns: | |||
| numpy.ndarray. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> array = data.asnumpy() | |||
| >>> array | |||
| array([[1., 1., 1.], | |||
| [1., 1., 1.]]) | |||
| )mydelimiter") | |||
| .def("size", &Tensor::DataSize, R"mydelimiter( | |||
| Get tensor's data size. | |||
| Returns: | |||
| int, the size of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.size() | |||
| 6 | |||
| )mydelimiter") | |||
| .def("is_init", &Tensor::is_init, R"mydelimiter( | |||
| Get tensor init_flag. | |||
| Returns: | |||
| bool, whether the tensor init. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.is_init() | |||
| False | |||
| )mydelimiter") | |||
| .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( | |||
| Set tensor init_flag. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.set_init_flag(True) | |||
| )mydelimiter") | |||
| .def("dim", &Tensor::DataDim, R"mydelimiter( | |||
| Get tensor's data dimension. | |||
| Returns: | |||
| int, the dimension of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.dim() | |||
| 2 | |||
| )mydelimiter") | |||
| .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( | |||
| Set the tensor's data type. | |||
| Arg: | |||
| dtype (:class:`mindspore.dtype`): The type of output tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) | |||
| >>> data.set_dtype(mindspore.int32) | |||
| mindspore.int32 | |||
| )mydelimiter") | |||
| .def("assign_value", &Tensor::AssignValue, R"mydelimiter( | |||
| Assign another tensor value to this. | |||
| Arg: | |||
| value (:class:`mindspore.tensor`): The value tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) | |||
| >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) | |||
| >>> data.assign_value(data2) | |||
| >>> data.shape | |||
| (2, 2) | |||
| )mydelimiter") | |||
| .def("__str__", &Tensor::ToString) | |||
| .def("__repr__", &Tensor::ToStringRepr) | |||
| .def(py::pickle( | |||
| [](const Tensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(t.data()); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Tensor tensor(t[0].cast<py::array>()); | |||
| return tensor; | |||
| })); | |||
| (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") | |||
| .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | |||
| .def(py::pickle( | |||
| [](const MetaTensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(static_cast<int>(t.data_type()), t.shape()); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 2) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>()); | |||
| return tensor; | |||
| })) | |||
| .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) | |||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); | |||
| })); | |||
| TypeId Tensor::set_data_type(const TypeId data_type) { | |||
| if (data_type != data_type_) { | |||
| data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_); | |||
| return MetaTensor::set_data_type(data_type); | |||
| } | |||
| return data_type; | |||
| } | |||
| } // namespace tensor | |||
| namespace inference { | |||
| @@ -530,8 +333,6 @@ MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape | |||
| return new Tensor(data_type, shape); | |||
| } | |||
| Tensor::Tensor() { this->tensor_impl_ = std::make_shared<tensor::Tensor>(); } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { | |||
| this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape); | |||
| } | |||
| @@ -585,7 +386,8 @@ size_t Tensor::Size() const { | |||
| void *Tensor::MutableData() const { | |||
| MS_ASSERT(this->tensor_impl_ != nullptr); | |||
| return this->tensor_impl_->data_c(true); | |||
| return this->tensor_impl_->data_c(); | |||
| } | |||
| } // namespace inference | |||
| } // namespace mindspore | |||
| @@ -20,9 +20,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "pybind11/numpy.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include <numeric> | |||
| #include "Eigen/Core" | |||
| #include "device/device_address.h" | |||
| @@ -30,63 +28,8 @@ | |||
| #include "include/ms_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace py = pybind11; | |||
| using float16 = Eigen::half; | |||
| namespace pybind11 { | |||
| namespace detail { | |||
| // Similar to enums in `pybind11/numpy.h`. Determined by doing: | |||
| // python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' | |||
| constexpr int NPY_FLOAT16 = 23; | |||
| template <typename T> | |||
| struct npy_scalar_caster { | |||
| PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); | |||
| using Array = array_t<T>; | |||
| bool load(handle src, bool convert) { | |||
| // Taken from Eigen casters. Permits either scalar dtype or scalar array. | |||
| handle type = dtype::of<T>().attr("type"); | |||
| if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false; | |||
| Array tmp = Array::ensure(src); | |||
| if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { | |||
| this->value = *tmp.data(); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| static handle cast(T src, return_value_policy, handle) { | |||
| Array tmp({1}); | |||
| tmp.mutable_at(0) = src; | |||
| tmp.resize({}); | |||
| // You could also just return the array if you want a scalar array. | |||
| object scalar = tmp[tuple()]; | |||
| return scalar.release(); | |||
| } | |||
| }; | |||
| template <> | |||
| struct npy_format_descriptor<float16> { | |||
| static constexpr auto name = "float16"; | |||
| static pybind11::dtype dtype() { | |||
| handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); | |||
| return reinterpret_borrow<pybind11::dtype>(ptr); | |||
| } | |||
| virtual ~npy_format_descriptor<float16>() {} | |||
| }; | |||
| template <> | |||
| struct type_caster<float16> : public npy_scalar_caster<float16> { | |||
| static constexpr auto name = "float16"; | |||
| }; | |||
| } // namespace detail | |||
| } // namespace pybind11 | |||
| using mindspore::device::DeviceAddress; | |||
| using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>; | |||
| // brief mindspore namespace. | |||
| @@ -98,179 +41,195 @@ namespace mindspore { | |||
| // | |||
| // A sub namespace in ME to support tensor related definition. | |||
| namespace tensor { | |||
| // Tensor data interface. | |||
| class TensorData { | |||
| public: | |||
| /// Total number of elements. | |||
| virtual ssize_t size() const = 0; | |||
| /// Byte size of a single element. | |||
| virtual ssize_t itemsize() const = 0; | |||
| /// Total number of bytes. | |||
| virtual ssize_t nbytes() const = 0; | |||
| /// Number of dimensions. | |||
| virtual ssize_t ndim() const = 0; | |||
| /// Data pointer. | |||
| virtual void *data() = 0; | |||
| /// Is data equals. | |||
| virtual bool equals(const TensorData &other) const = 0; | |||
| /// To string. | |||
| virtual std::string ToString() const = 0; | |||
| }; | |||
| using TensorDataPtr = std::shared_ptr<TensorData>; | |||
| // Tensor entity class | |||
| class Tensor : public MetaTensor { | |||
| public: | |||
| Tensor() = default; | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| // brief Constructor for Python. | |||
| // brief Create tensor from another tensor, data is shared. | |||
| // | |||
| // param tensor [Tensor] The input tensor. | |||
| explicit Tensor(const Tensor &tensor); | |||
| // brief Create tensor with given data type from another tensor. | |||
| // | |||
| // param type_ptr [TypePty] Data type of the tensor. | |||
| // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. | |||
| Tensor(const TypePtr &type_ptr, const py::tuple &shape); | |||
| // param tensor [Tensor] The input tensor. | |||
| // param data_type [TypeId] The new tensor data type. | |||
| Tensor(const Tensor &tensor, TypeId data_type); | |||
| // brief Constructor for C++. | |||
| // brief Create tensor with the given shared tensor data. | |||
| // | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape The shape represented by std::vector<int> of the tensor. | |||
| // param data The shared tensor data. | |||
| Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data); | |||
| // brief Create an all zero tensor. | |||
| // | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape The shape represented by std::vector<int> of the tensor. | |||
| Tensor(TypeId data_type, const std::vector<int> &shape); | |||
| // brief Constructor for Python. | |||
| // brief Create a tensor with input data buffer. | |||
| // | |||
| // param input [py::array] Data value of the tensor. | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); | |||
| // param shape The shape represented by std::vector<int> of the tensor. | |||
| // param data The input data to be copied into tensor. | |||
| // param data_len The length of data in bytes. | |||
| Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len); | |||
| // brief Constructor | |||
| // brief Create a tensor with input data buffer and given source data type. | |||
| // | |||
| // param input [py::list] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape The shape represented by std::vector<int> of the tensor. | |||
| // param data The input data to be copied into tensor. | |||
| // param src_data_type The source data type. | |||
| Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type); | |||
| // brief Constructor | |||
| // brief Create 1 dimension tensor from an int vector. | |||
| // | |||
| // param input [py::tuple] the data for tensor | |||
| // param input [std::vector<int64_t>] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); | |||
| explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // brief Create 1 dimension tensor from a float vector. | |||
| // | |||
| // param input [py::float_] the data for tensor | |||
| // param input [std::vector<double>] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); | |||
| explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // brief Create 0 dimension tensor from an int scalar. | |||
| // | |||
| // param input [py::int_] the data for tensor | |||
| // param input [int64] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); | |||
| explicit Tensor(int64_t input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // brief Create 0 dimension tensor from a float scalar. | |||
| // | |||
| // param input [Tensor] the data for tensor | |||
| // param input [double] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); | |||
| explicit Tensor(double input, const TypePtr &data_type = nullptr); | |||
| ~Tensor() override = default; | |||
| MS_DECLARE_PARENT(Tensor, MetaTensor); | |||
| // brief Overloads operator = for Tensor. | |||
| // | |||
| // The constructed Tensor object has the same type and shape with tensor. | |||
| // | |||
| // param tensor An existing Tensor object. | |||
| Tensor &operator=(const Tensor &tensor); | |||
| // brief Compares two Tensor objects. | |||
| // | |||
| // Compare two tensor objects to see if they have same data type, shape and | |||
| // data value. | |||
| // Compare two tensor objects to see if they have same data type, shape and data address. | |||
| // | |||
| // param tensor The Tensor object to be compared. | |||
| // return true: If having same type, shape and data, return true, or return false. | |||
| // return true: If having same type, shape and data address, return true, or return false. | |||
| bool operator==(const Tensor &tensor) const; | |||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | |||
| bool ValueEqual(const Tensor &other) const; | |||
| // assgin value to this tensor | |||
| Tensor &AssignValue(const Tensor &tensor); | |||
| // It is different from 'operator==' which just compare shape/type/address, | |||
| // it do real value comparison. | |||
| bool ValueEqual(const Tensor &tensor) const; | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<Tensor>()) { | |||
| auto other_ = static_cast<const Tensor &>(other); | |||
| auto &other_ = static_cast<const Tensor &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| return false; | |||
| } | |||
| py::tuple GetPyTupleShape() const; | |||
| // brief Gets tensor's dimension | |||
| // | |||
| // return The number of dimensions of the tensor data. | |||
| int DataDim() const; | |||
| int DataDim() const { return static_cast<int>(data().ndim()); } | |||
| // brief Getting tensor data size | |||
| // | |||
| // return The total number of elements of the tensor data. | |||
| int DataSize() const; | |||
| // brief Tensor's data value. | |||
| // | |||
| // return [py::array] The tensor's data in py::array. | |||
| py::array data() const; | |||
| int DataSize() const { return static_cast<int>(data().size()); } | |||
| // brief Get the data type fo the tensor for C++ | |||
| // | |||
| // return [int] The tensor's data type will be cast to int to return. | |||
| int data_type_c() const; | |||
| int data_type_c() const { return static_cast<int>(data_type_); } | |||
| // brief Get the tensor's shape for C++ | |||
| // | |||
| // return [std::vector<int>] | |||
| std::vector<int> shape_c(void) const; | |||
| std::vector<int> shape_c(void) const { return shape(); } | |||
| // brief Get Tensor data pointer for c++ type | |||
| // | |||
| // param writable true if writable, false if read only | |||
| // return The pointer to the object | |||
| void *data_c(bool writable = false); | |||
| void *data_c() { return data().data(); } | |||
| // brief Get Tensor data byte-size for c++ type | |||
| // | |||
| // return byte size of Tensor data | |||
| size_t Size() const { return this->data().nbytes(); } | |||
| size_t Size() const { return data().nbytes(); } | |||
| // brief Get data type from tensor data. | |||
| void *data_c() const { return data_->data(); } | |||
| // brief Sync data with device. | |||
| void data_sync() const; | |||
| // brief Get the internal data object. | |||
| // | |||
| // param buf The buffer info of the py::array data. | |||
| // return The [TypeId] of the tensor data. | |||
| TypeId GetDataType(const py::buffer_info &buf) const; | |||
| // return The reference to internal data object. | |||
| TensorData &data() { return *data_; } | |||
| // brief Sets the data type of a tensor. | |||
| // brief Get the internal data shared pointer. | |||
| // | |||
| // param data_type The data type of the tensor to be set. | |||
| // return The reference to internal data object. | |||
| const TensorDataPtr &data_ptr() const { return data_; } | |||
| // brief Get the internal data object. | |||
| // | |||
| // return The reference to internal data object. | |||
| const TensorData &data() const { return *data_; } | |||
| TypeId set_data_type(const TypeId data_type) override; | |||
| TypePtr SetDtype(const TypePtr type_ptr) override; | |||
| std::string GetShapeAndDataTypeInfo() const; | |||
| std::string ToString() const override; | |||
| std::string ToStringRepr() const; | |||
| py::array data_; // < Tensor's data value | |||
| const bool parse_info_ = true; | |||
| bool is_init(); | |||
| void set_init_flag(bool flag); | |||
| private: | |||
| // brief init tensor | |||
| // | |||
| // param input [py::array] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| // return true if succeed, false if failed. | |||
| void init(const py::array &input, const TypeId &data_type); | |||
| void init(const py::array &input, const TypePtr &type_ptr); | |||
| bool init_flag_{false}; | |||
| // brief init tensor attribute | |||
| // | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape [py::array] The shape of the tensor. | |||
| // return true if succeed, false if failed. | |||
| void init(TypeId data_type, const std::vector<int> &shape, py::array *data); | |||
| std::string ToStringRepr() const; | |||
| bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); | |||
| bool is_init() { return init_flag_; } | |||
| void set_init_flag(bool flag) { init_flag_ = flag; } | |||
| public: | |||
| bool is_dirty() const { return dirty_; } | |||
| void set_dirty(const bool dirty) { dirty_ = dirty; } | |||
| DeviceAddressPtr device_address() const { return device_address_; } | |||
| void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } | |||
| py::array data_sync(); | |||
| std::string id() const { return id_; } | |||
| const bool parse_info_ = true; | |||
| private: | |||
| bool init_flag_{false}; | |||
| TensorDataPtr data_{nullptr}; | |||
| bool dirty_{true}; | |||
| std::string id_{""}; | |||
| DeviceAddressPtr device_address_{nullptr}; | |||
| @@ -282,8 +241,6 @@ using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; | |||
| namespace inference { | |||
| class Tensor : public MSTensor { | |||
| public: | |||
| Tensor(); | |||
| Tensor(TypeId data_type, const std::vector<int> &shape); | |||
| explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr); | |||
| @@ -0,0 +1,377 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/tensor_py.h" | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include <vector> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include "device/device_address.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| static TypeId GetDataType(const py::buffer_info &buf) { | |||
| if (buf.format.size() == 1) { | |||
| switch (buf.format.front()) { | |||
| case 'e': | |||
| case 'f': | |||
| case 'd': | |||
| switch (buf.itemsize) { | |||
| case 2: | |||
| return TypeId::kNumberTypeFloat16; | |||
| case 4: | |||
| return TypeId::kNumberTypeFloat32; | |||
| case 8: | |||
| return TypeId::kNumberTypeFloat64; | |||
| } | |||
| break; | |||
| case 'b': | |||
| case 'h': | |||
| case 'i': | |||
| case 'l': | |||
| case 'q': | |||
| switch (buf.itemsize) { | |||
| case 1: | |||
| return TypeId::kNumberTypeInt8; | |||
| case 2: | |||
| return TypeId::kNumberTypeInt16; | |||
| case 4: | |||
| return TypeId::kNumberTypeInt32; | |||
| case 8: | |||
| return TypeId::kNumberTypeInt64; | |||
| } | |||
| break; | |||
| case 'B': | |||
| case 'H': | |||
| case 'I': | |||
| case 'L': | |||
| case 'Q': | |||
| switch (buf.itemsize) { | |||
| case 1: | |||
| return TypeId::kNumberTypeUInt8; | |||
| case 2: | |||
| return TypeId::kNumberTypeUInt16; | |||
| case 4: | |||
| return TypeId::kNumberTypeUInt32; | |||
| case 8: | |||
| return TypeId::kNumberTypeUInt64; | |||
| } | |||
| break; | |||
| case '?': | |||
| return TypeId::kNumberTypeBool; | |||
| } | |||
| } | |||
| MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; | |||
| return TypeId::kTypeUnknown; | |||
| } | |||
| static std::string GetPyTypeFormat(TypeId data_type) { | |||
| switch (data_type) { | |||
| case TypeId::kNumberTypeFloat16: | |||
| return "e"; | |||
| case TypeId::kNumberTypeFloat32: | |||
| return py::format_descriptor<float>::format(); | |||
| case TypeId::kNumberTypeFloat64: | |||
| return py::format_descriptor<double>::format(); | |||
| case TypeId::kNumberTypeUInt8: | |||
| return py::format_descriptor<uint8_t>::format(); | |||
| case TypeId::kNumberTypeUInt16: | |||
| return py::format_descriptor<uint16_t>::format(); | |||
| case TypeId::kNumberTypeUInt32: | |||
| return py::format_descriptor<uint32_t>::format(); | |||
| case TypeId::kNumberTypeUInt64: | |||
| return py::format_descriptor<uint64_t>::format(); | |||
| case TypeId::kNumberTypeInt8: | |||
| return py::format_descriptor<int8_t>::format(); | |||
| case TypeId::kNumberTypeInt16: | |||
| return py::format_descriptor<int16_t>::format(); | |||
| case TypeId::kNumberTypeInt32: | |||
| return py::format_descriptor<int32_t>::format(); | |||
| case TypeId::kNumberTypeInt64: | |||
| return py::format_descriptor<int64_t>::format(); | |||
| case TypeId::kNumberTypeBool: | |||
| return py::format_descriptor<bool>::format(); | |||
| default: | |||
| MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; | |||
| return ""; | |||
| } | |||
| } | |||
| static bool IsCContiguous(const py::array &input) { | |||
| auto flags = static_cast<unsigned int>(input.flags()); | |||
| return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; | |||
| } | |||
| TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { | |||
| // Get input buffer info. | |||
| py::buffer_info buf = input.request(); | |||
| // Check data types. | |||
| auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown; | |||
| auto buf_type = GetDataType(buf); | |||
| if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { | |||
| MS_LOG(EXCEPTION) << "Unsupported tensor type!"; | |||
| } | |||
| // Use buf type as data type if type_ptr not set. | |||
| if (data_type == TypeId::kTypeUnknown) { | |||
| data_type = buf_type; | |||
| } | |||
| // Convert input array to C contiguous if need. | |||
| std::unique_ptr<char[]> tmp_buf; | |||
| if (!IsCContiguous(input)) { | |||
| Py_buffer pybuf; | |||
| if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) { | |||
| MS_LOG(EXCEPTION) << "Failed to get buffer from the input!"; | |||
| } | |||
| tmp_buf = std::make_unique<char[]>(pybuf.len); | |||
| if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) { | |||
| MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; | |||
| } | |||
| PyBuffer_Release(&pybuf); | |||
| buf.ptr = tmp_buf.get(); | |||
| } | |||
| // Get tensor shape. | |||
| std::vector<int> shape(buf.shape.begin(), buf.shape.end()); | |||
| if (data_type == buf_type) { | |||
| // Use memory copy if input data type is same as the required type. | |||
| return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize); | |||
| } | |||
| // Create tensor with data type converted. | |||
| return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type); | |||
| } | |||
| static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) { | |||
| std::vector<ssize_t> strides; | |||
| strides.reserve(shape.size()); | |||
| const auto ndim = shape.size(); | |||
| for (size_t i = 0; i < ndim; ++i) { | |||
| auto stride = item_size; | |||
| for (size_t j = i + 1; j < ndim; ++j) { | |||
| stride *= shape[j]; | |||
| } | |||
| strides.push_back(stride); | |||
| } | |||
| return strides; | |||
| } | |||
| static py::buffer_info GetPyBufferInfo(const Tensor &tensor) { | |||
| std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end()); | |||
| std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize()); | |||
| return py::buffer_info{ | |||
| tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides}; | |||
| } | |||
| py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { | |||
| auto &shape = tensor.shape(); | |||
| py::tuple dims(shape.size()); | |||
| for (size_t i = 0; i < dims.size(); ++i) { | |||
| dims[i] = py::int_(shape[i]); | |||
| } | |||
| return dims; | |||
| } | |||
| py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { | |||
| tensor.data_sync(); | |||
| auto info = GetPyBufferInfo(tensor); | |||
| py::object self = py::cast(&tensor); | |||
| return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); | |||
| } | |||
| py::array TensorPy::AsNumpy(const Tensor &tensor) { | |||
| auto info = GetPyBufferInfo(tensor); | |||
| py::object self = py::cast(&tensor); | |||
| return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); | |||
| } | |||
| static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) { | |||
| std::vector<int> shape; | |||
| const size_t size = tuple.size(); | |||
| shape.reserve(tuple.size()); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| shape.push_back(py::int_(tuple[i])); | |||
| } | |||
| return shape; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| // Define python Tensor class. | |||
| // dtype should define before Tensor, because Tensor init depend dtype | |||
| (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") | |||
| .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }), | |||
| py::arg("input")) | |||
| .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) { | |||
| TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown; | |||
| if (data_type == kTypeUnknown || tensor.data_type() == data_type) { | |||
| return std::make_shared<Tensor>(tensor); | |||
| } | |||
| return std::make_shared<Tensor>(tensor, data_type); | |||
| }), | |||
| py::arg("input"), py::arg("dtype")) | |||
| .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) { | |||
| auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; | |||
| return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape)); | |||
| }), | |||
| py::arg("dtype"), py::arg("shape")) | |||
| .def(py::init([](const py::array &input, const TypePtr &type_ptr) { | |||
| return TensorPy::MakeTensor(input, type_ptr); | |||
| }), | |||
| py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init([](py::float_ input, const TypePtr &type_ptr) { | |||
| return TensorPy::MakeTensor(py::array(input), type_ptr); | |||
| }), | |||
| py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init([](py::int_ input, const TypePtr &type_ptr) { | |||
| return TensorPy::MakeTensor(py::array(input), type_ptr); | |||
| }), | |||
| py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init([](py::list input, const TypePtr &type_ptr) { | |||
| return TensorPy::MakeTensor(py::array(input), type_ptr); | |||
| }), | |||
| py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def(py::init([](py::tuple input, const TypePtr &type_ptr) { | |||
| return TensorPy::MakeTensor(py::array(input), type_ptr); | |||
| }), | |||
| py::arg("input"), py::arg("dtype") = nullptr) | |||
| .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) | |||
| .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( | |||
| Get the tensor's data type. | |||
| Returns: | |||
| type, the data type of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) | |||
| >>> data.dtype | |||
| Int32 | |||
| )mydelimiter") | |||
| .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( | |||
| Get the tensor's shape. | |||
| Returns: | |||
| tuple[int], the shape of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((3, 3))) | |||
| >>> data.shape() | |||
| (3, 3) | |||
| )mydelimiter") | |||
| .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( | |||
| Convert tensor to numpy.ndarray. | |||
| Returns: | |||
| numpy.ndarray. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> array = data.asnumpy() | |||
| >>> array | |||
| array([[1., 1., 1.], | |||
| [1., 1., 1.]]) | |||
| )mydelimiter") | |||
| .def("size", &Tensor::DataSize, R"mydelimiter( | |||
| Get tensor's data size. | |||
| Returns: | |||
| int, the size of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.size() | |||
| 6 | |||
| )mydelimiter") | |||
| .def("is_init", &Tensor::is_init, R"mydelimiter( | |||
| Get tensor init_flag. | |||
| Returns: | |||
| bool, whether the tensor init. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.is_init() | |||
| False | |||
| )mydelimiter") | |||
| .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( | |||
| Set tensor init_flag. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.set_init_flag(True) | |||
| )mydelimiter") | |||
| .def("dim", &Tensor::DataDim, R"mydelimiter( | |||
| Get tensor's data dimension. | |||
| Returns: | |||
| int, the dimension of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.dim() | |||
| 2 | |||
| )mydelimiter") | |||
| .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( | |||
| Set the tensor's data type. | |||
| Arg: | |||
| dtype (:class:`mindspore.dtype`): The type of output tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) | |||
| >>> data.set_dtype(mindspore.int32) | |||
| mindspore.int32 | |||
| )mydelimiter") | |||
| .def("__str__", &Tensor::ToString) | |||
| .def("__repr__", &Tensor::ToStringRepr) | |||
| .def(py::pickle( | |||
| [](const Tensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(TensorPy::AsNumpy(t)); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| return TensorPy::MakeTensor(t[0].cast<py::array>()); | |||
| })); | |||
| // Define python MetaTensor class. | |||
| (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor") | |||
| .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape")) | |||
| .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) | |||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") | |||
| .def(py::pickle( | |||
| [](const MetaTensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(static_cast<int>(t.data_type()), t.shape()); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 2) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>()); | |||
| return tensor; | |||
| })); | |||
| })); | |||
| } // namespace tensor | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,114 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_TENSOR_PY_H_ | |||
| #define MINDSPORE_CCSRC_IR_TENSOR_PY_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/numpy.h" | |||
| #include "ir/tensor.h" | |||
| namespace py = pybind11; | |||
| namespace pybind11 { | |||
| namespace detail { | |||
| // Similar to enums in `pybind11/numpy.h`. Determined by doing: | |||
| // python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' | |||
| constexpr int NPY_FLOAT16 = 23; | |||
| template <typename T> | |||
| struct npy_scalar_caster { | |||
| PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); | |||
| using Array = array_t<T>; | |||
| bool load(handle src, bool convert) { | |||
| // Taken from Eigen casters. Permits either scalar dtype or scalar array. | |||
| handle type = dtype::of<T>().attr("type"); | |||
| if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false; | |||
| Array tmp = Array::ensure(src); | |||
| if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { | |||
| this->value = *tmp.data(); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| static handle cast(T src, return_value_policy, handle) { | |||
| Array tmp({1}); | |||
| tmp.mutable_at(0) = src; | |||
| tmp.resize({}); | |||
| // You could also just return the array if you want a scalar array. | |||
| object scalar = tmp[tuple()]; | |||
| return scalar.release(); | |||
| } | |||
| }; | |||
| template <> | |||
| struct npy_format_descriptor<float16> { | |||
| static constexpr auto name = "float16"; | |||
| static pybind11::dtype dtype() { | |||
| handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); | |||
| return reinterpret_borrow<pybind11::dtype>(ptr); | |||
| } | |||
| virtual ~npy_format_descriptor<float16>() {} | |||
| }; | |||
| template <> | |||
| struct type_caster<float16> : public npy_scalar_caster<float16> { | |||
| static constexpr auto name = "float16"; | |||
| }; | |||
| } // namespace detail | |||
| } // namespace pybind11 | |||
| using mindspore::device::DeviceAddress; | |||
| using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>; | |||
| // brief mindspore namespace. | |||
| // | |||
| // mindspore namespace is the top level namespace of Mindsporeession project. | |||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | |||
| namespace mindspore { | |||
| // brief mindspore::tensor namespace | |||
| // | |||
| // A sub namespace in ME to support tensor related definition. | |||
| namespace tensor { | |||
| // Tensor python wrapper and adapter class. | |||
| class TensorPy { | |||
| public: | |||
| // brief Create Tensor from a numpy array object. | |||
| // | |||
| // param input [py::array] Data value of the tensor. | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); | |||
| static py::array SyncAsNumpy(const Tensor &tensor); | |||
| static py::array AsNumpy(const Tensor &tensor); | |||
| static py::tuple GetPyTupleShape(const Tensor &tensor); | |||
| }; | |||
| } // namespace tensor | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_TENSOR_PY_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include "ir/tensor_py.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "debug/anf_ir_utils.h" | |||
| #include "operator/ops.h" | |||
| @@ -257,7 +258,7 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| auto data = value->cast<tensor::TensorPtr>(); | |||
| tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes())); | |||
| tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | |||
| auto dtype = data->data_type(); | |||
| auto shape = data->shape_c(); | |||
| tensor_proto->set_data_type(GetOnnxDataType(dtype)); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "proto/onnx.pb.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "ir/tensor_py.h" | |||
| namespace mindspore { | |||
| enum OpMergeMode { | |||
| @@ -1190,7 +1191,7 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| auto data = dyn_cast<tensor::Tensor>(value); | |||
| tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes())); | |||
| tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | |||
| auto dtype = data->data_type(); | |||
| auto shape = data->shape_c(); | |||
| @@ -21,6 +21,9 @@ | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ir/tensor_py.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -554,7 +557,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP | |||
| py::tuple data_tuple = ValuePtrToPyData(input->BuildValue()); | |||
| py::array data = py::array(data_tuple); | |||
| auto tensor = std::make_shared<tensor::Tensor>(data); | |||
| auto tensor = TensorPy::MakeTensor(data); | |||
| auto ret = tensor->ToAbstract(); | |||
| ret->set_value(tensor); | |||
| MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString(); | |||
| @@ -153,7 +153,7 @@ class TensorMultiplyBase : public AnfVisitor { | |||
| } | |||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||
| return tensor_ptr->data_c(writable); | |||
| return tensor_ptr->data_c(); | |||
| } | |||
| // Make a new tensor (when possible) with the same shape as of `node` | |||
| @@ -171,7 +171,7 @@ class TensorMultiplyBase : public AnfVisitor { | |||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||
| if (x == nullptr) { | |||
| std::memset(data, 0, mem_size); | |||
| @@ -546,7 +546,7 @@ class ConstantDuplicateMul : public AnfVisitor { | |||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape); | |||
| size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||
| memcpy(data, data_out, mem_size); | |||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||
| @@ -191,7 +191,7 @@ inline void ResetSharedOp() { | |||
| tensor::TensorPtr ConstData() { | |||
| std::vector<int> shp = {1}; | |||
| tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| auto *val = static_cast<int32_t *>(const_data->data_c(true)); | |||
| auto *val = static_cast<int32_t *>(const_data->data_c()); | |||
| *val = 0; | |||
| return const_data; | |||
| } | |||
| @@ -267,7 +267,7 @@ CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNod | |||
| auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||
| std::vector<int> shp = {1}; | |||
| tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| auto *val = static_cast<int32_t *>(const_data->data_c(true)); | |||
| auto *val = static_cast<int32_t *>(const_data->data_c()); | |||
| *val = 0; | |||
| // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same | |||
| // switch the other use the opposite | |||
| @@ -178,7 +178,7 @@ class ZeroLikeFillZero : public AnfVisitor { | |||
| tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true)); | |||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||
| (void)memset_s(data, mem_size, 0, mem_size); | |||
| auto new_cnode = NewValueNode(new_tensor_ptr); | |||
| @@ -71,7 +71,7 @@ class SpecializeTransform { | |||
| continue; | |||
| } | |||
| if (value_args[i] != nullptr) { | |||
| auto const_tensor = *value_args[i]; | |||
| auto &const_tensor = *value_args[i]; | |||
| auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor); | |||
| AnfNodePtr arg = NewValueNode(const_tensor_ptr); | |||
| (void)mng->Replace(params[i], arg); | |||
| @@ -210,8 +210,8 @@ OperatorVector CreateSubOp(int32_t sub_value) { | |||
| OperatorName operator_name = SUB; | |||
| OperatorAttrs operator_attrs; | |||
| py::tuple tuple = py::make_tuple(sub_value); | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, kInt32); | |||
| std::vector<int64_t> tensor_data = {sub_value}; | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32); | |||
| ValuePtr op_param_value = MakeValue(tensor_ptr); | |||
| Attr op1_param = std::make_pair("", op_param_value); | |||
| @@ -204,8 +204,8 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, cons | |||
| OperatorName operator1_name = REAL_DIV; | |||
| std::vector<Device> device_list = forward_group[0].GetDevicesList(); | |||
| auto divisor = static_cast<float>(device_list.size()); | |||
| py::tuple tuple = py::make_tuple(divisor); | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, dtype); | |||
| std::vector<double> tensor_data = {divisor}; | |||
| mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, dtype); | |||
| ValuePtr op1_param_value = MakeValue(tensor_ptr); | |||
| Attr op1_param = std::make_pair("divisor", op1_param_value); | |||
| OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; | |||
| @@ -156,11 +156,11 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) | |||
| if (py::isinstance<py::float_>(item.second.attr("default_input"))) { | |||
| // convert float to tensor with shape([1]) | |||
| tensor = std::make_shared<Tensor>(kNumberTypeFloat32, std::vector<int>({1})); | |||
| *(static_cast<float *>(tensor->data_c(true))) = py::cast<float>(item.second.attr("default_input")); | |||
| *(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("default_input")); | |||
| } else if (py::isinstance<py::int_>(item.second.attr("default_input"))) { | |||
| // convert int to tensor with shape([1]) | |||
| tensor = std::make_shared<Tensor>(kNumberTypeInt32, std::vector<int>({1})); | |||
| *(static_cast<float *>(tensor->data_c(true))) = py::cast<float>(item.second.attr("default_input")); | |||
| *(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("default_input")); | |||
| } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { | |||
| // cast tensor | |||
| tensor = py::cast<std::shared_ptr<Tensor>>(item.second.attr("default_input")); | |||
| @@ -330,7 +330,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t | |||
| MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString(); | |||
| } | |||
| auto shape_me = shape->cast<abstract::ShapePtr>()->shape(); | |||
| auto shape_ge = py::cast<Tensor>(data[*count]).shape(); | |||
| auto shape_ge = py::cast<Tensor &>(data[*count]).shape(); | |||
| if (shape_ge != shape_me) { | |||
| MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge | |||
| << " is not the same as the shape of the tensor derived: " << shape_me; | |||
| @@ -44,7 +44,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { | |||
| indices_tensor->set_device_info(device_info); | |||
| // 2 set value of tensor | |||
| auto data_ptr = indices_tensor->data_c(true); | |||
| auto data_ptr = indices_tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data_ptr); | |||
| std::vector<Eigen::half> half_data; | |||
| for (size_t i = 0; i < last_dim; ++i) { | |||
| @@ -348,7 +348,7 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; | |||
| tensor->set_device_info(device_info); | |||
| auto data_ptr = tensor->data_c(true); | |||
| auto data_ptr = tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data_ptr); | |||
| auto elem_num = values.size() * data_length; | |||
| auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), values.data(), elem_num); | |||
| @@ -538,7 +538,7 @@ bool Kernel2Ms::KernelInput2MS(const std::vector<TensorPtr> &input_tensors) { | |||
| auto match_idx = match_to_rel_idxs[j]; | |||
| auto real_tensor = input_tensors[match_idx]; | |||
| auto real_size = LongToSize(real_tensor->data().nbytes()); | |||
| auto real_data = real_tensor->data_c(false); | |||
| auto real_data = real_tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(real_data); | |||
| if (sub_ms_graph_->allTensors[cache_idx] != nullptr) { | |||
| sub_ms_graph_->allTensors[cache_idx]->data.resize(real_size); | |||
| @@ -22,6 +22,7 @@ | |||
| #include <unordered_set> | |||
| #include <algorithm> | |||
| #include "ir/tensor_py.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "utils/any.h" | |||
| #include "utils/utils.h" | |||
| @@ -51,6 +52,8 @@ | |||
| #include "pynative/pynative_execute_ge.h" | |||
| #endif | |||
| using mindspore::tensor::TensorPy; | |||
| const char SINGLE_OP_GRAPH[] = "single_op_graph"; | |||
| // primitive unable to infer value for constant input in PyNative mode | |||
| const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; | |||
| @@ -171,7 +174,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu | |||
| py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype()); | |||
| (*out_args_list)[i] = py_args[i]; | |||
| } else { | |||
| py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype()); | |||
| double arg_value = py::cast<py::float_>(py_args[i]); | |||
| py_args[i] = std::make_shared<tensor::Tensor>(arg_value, tensor_ptr->Dtype()); | |||
| (*out_args_list)[i] = py_args[i]; | |||
| } | |||
| continue; | |||
| @@ -262,7 +266,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||
| result[i] = py::getattr(input, "data"); | |||
| } else { | |||
| auto tensor = py::cast<tensor::TensorPtr>(op_inputs[i]); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data()); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); | |||
| result[i] = new_tensor; | |||
| } | |||
| } | |||
| @@ -366,13 +370,14 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr | |||
| if (py::isinstance<tensor::Tensor>(input_object)) { | |||
| tensor_ptr = py::cast<tensor::TensorPtr>(input_object); | |||
| } else if (py::isinstance<py::float_>(input_object)) { | |||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32); | |||
| double input_value = py::cast<py::float_>(input_object); | |||
| tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); | |||
| *tensor_mask = kValueNodeTensorMask; | |||
| } else if (py::isinstance<py::int_>(input_object)) { | |||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32); | |||
| *tensor_mask = kValueNodeTensorMask; | |||
| } else if (py::isinstance<py::array>(input_object)) { | |||
| tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr); | |||
| tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); | |||
| } else if (py::isinstance<py::list>(input_object)) { | |||
| auto list_inputs = py::cast<py::list>(input_object); | |||
| py::tuple tuple_inputs(list_inputs.size()); | |||
| @@ -26,6 +26,7 @@ | |||
| #include <stack> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/numpy.h" | |||
| #include "pynative/base.h" | |||
| #include "utils/context/ms_context.h" | |||
| @@ -28,9 +28,12 @@ | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "pipeline/static_analysis/prim.h" | |||
| #include "session/session_factory.h" | |||
| #include "ir/tensor_py.h" | |||
| const char SINGLE_OP_GRAPH[] = "single_op_graph"; | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace pynative { | |||
| using MeTensor = mindspore::tensor::Tensor; | |||
| @@ -56,15 +59,15 @@ MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { | |||
| if (py::isinstance<MeTensor>(obj)) { | |||
| me_tensor_ptr = py::cast<MeTensorPtr>(obj); | |||
| } else if (py::isinstance<py::tuple>(obj)) { | |||
| me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::tuple>(obj), nullptr); | |||
| me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::tuple>(obj)), nullptr); | |||
| } else if (py::isinstance<py::float_>(obj)) { | |||
| me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::float_>(obj), nullptr); | |||
| me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::float_>(obj)), nullptr); | |||
| } else if (py::isinstance<py::int_>(obj)) { | |||
| me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::int_>(obj), nullptr); | |||
| me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::int_>(obj)), nullptr); | |||
| } else if (py::isinstance<py::list>(obj)) { | |||
| me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::list>(obj), nullptr); | |||
| me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast<py::list>(obj)), nullptr); | |||
| } else if (py::isinstance<py::array>(obj)) { | |||
| me_tensor_ptr = std::make_shared<MeTensor>(py::cast<py::array>(obj), nullptr); | |||
| me_tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(obj), nullptr); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| #include "session/ascend_inference_session.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/tensor_py.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "device/kernel_runtime.h" | |||
| @@ -26,6 +27,8 @@ | |||
| #include "utils/config_manager.h" | |||
| #include "utils/base_ref_extends.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace session { | |||
| void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| @@ -51,7 +54,7 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k | |||
| auto py_param = param_value->value(); | |||
| MS_EXCEPTION_IF_NULL(py_param); | |||
| py::array py_array = py_param.cast<py::array>(); | |||
| tensor = std::make_shared<tensor::Tensor>(py_array); | |||
| tensor = TensorPy::MakeTensor(py_array); | |||
| } else { | |||
| tensor = inputs[no_weight_input++]; | |||
| } | |||
| @@ -78,7 +81,7 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| } | |||
| @@ -989,7 +989,7 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true | |||
| MS_EXCEPTION_IF_NULL(condition_graph); | |||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int>{1}); | |||
| int32_t *val = nullptr; | |||
| val = static_cast<int32_t *>(tensor->data_c(true)); | |||
| val = static_cast<int32_t *>(tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = 0; | |||
| auto value_node = std::make_shared<ValueNode>(tensor); | |||
| @@ -1523,7 +1523,7 @@ void AscendSession::SyncInitialTenosrToDevice() { | |||
| auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, | |||
| front_tensor->data_type(), front_tensor->data_c(false))) { | |||
| front_tensor->data_type(), front_tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; | |||
| } | |||
| } | |||
| @@ -129,7 +129,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| } | |||
| @@ -96,8 +96,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | |||
| tensor->set_dirty(false); | |||
| } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(true))) { | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { | |||
| MS_LOG(INFO) << "output sync device to host error!!!"; | |||
| tensor->set_dirty(false); | |||
| } | |||
| @@ -218,7 +217,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto | |||
| } | |||
| auto tensor = (*inputs_params)[0]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto *val = static_cast<int32_t *>(tensor->data_c(true)); | |||
| auto *val = static_cast<int32_t *>(tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| *val = 0; | |||
| tensor->set_dirty(true); | |||
| @@ -720,7 +719,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| } | |||
| @@ -815,7 +814,7 @@ void SessionBasic::Summary(KernelGraph *graph) { | |||
| continue; | |||
| } | |||
| if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), | |||
| tensor->data_type(), tensor->data_c(true))) { | |||
| tensor->data_type(), tensor->data_c())) { | |||
| MS_LOG(ERROR) << "Failed to sync output from device to host."; | |||
| } | |||
| tensor->set_dirty(false); | |||
| @@ -342,7 +342,7 @@ MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const | |||
| MeTensor me_tensor(me_type, me_dims); | |||
| // Get the writable data pointer of the tensor and cast it to its data type | |||
| auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c(true)); | |||
| auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c()); | |||
| size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes()); | |||
| MS_EXCEPTION_IF_NULL(me_data_ptr); | |||
| MS_EXCEPTION_IF_NULL(ge_tensor); | |||
| @@ -579,11 +579,12 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { | |||
| } | |||
| tensor::TensorPtr tensor = nullptr; | |||
| if (scalar->isa<FloatImm>()) { | |||
| tensor = std::make_shared<tensor::Tensor>(py::float_(GetValue<float>(scalar)), kFloat32); | |||
| tensor = std::make_shared<tensor::Tensor>(static_cast<double>(GetValue<float>(scalar)), kFloat32); | |||
| } else if (scalar->isa<IntergerImm>()) { | |||
| tensor = std::make_shared<tensor::Tensor>(py::int_(GetValue<int>(scalar)), kInt32); | |||
| tensor = std::make_shared<tensor::Tensor>(static_cast<int64_t>(GetValue<int>(scalar)), kInt32); | |||
| } else if (scalar->isa<BoolImm>()) { | |||
| tensor = std::make_shared<tensor::Tensor>(py::array(py::bool_(GetValue<bool>(scalar))), kBool); | |||
| const int64_t bool_value = GetValue<bool>(scalar) ? 1 : 0; | |||
| tensor = std::make_shared<tensor::Tensor>(bool_value, kBool); | |||
| } else { | |||
| auto type = scalar->type(); | |||
| auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); | |||
| @@ -22,12 +22,14 @@ | |||
| #include <vector> | |||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/tensor_py.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "utils/log_adapter.h" | |||
| using mindspore::tensor::TensorPy; | |||
| using std::string; | |||
| namespace mindspore { | |||
| @@ -117,11 +119,11 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons | |||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | |||
| const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; | |||
| std::string initial_data = initialize_proto.raw_data(); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true)); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||
| MS_EXCEPTION_IF_NULL(tensor_data_buf); | |||
| memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); | |||
| py::array array_data = tensor_info->data(); | |||
| py::array array_data = TensorPy::AsNumpy(*tensor_info); | |||
| ParamValuePyPtr para_value_ptr = std::make_shared<ParamValuePy>(); | |||
| MS_EXCEPTION_IF_NULL(para_value_ptr); | |||
| para_value_ptr->set_value(array_data); | |||
| @@ -249,7 +251,7 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node | |||
| } | |||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true)); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | |||
| memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); | |||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | |||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||
| @@ -87,7 +87,7 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co | |||
| const size_t &memory_size) { | |||
| MS_EXCEPTION_IF_NULL(str_data_ptr); | |||
| MS_EXCEPTION_IF_NULL(print_tensor); | |||
| auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true)); | |||
| auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c()); | |||
| MS_EXCEPTION_IF_NULL(tensor_data_ptr); | |||
| auto cp_ret = | |||
| memcpy_s(tensor_data_ptr, static_cast<size_t>(print_tensor->data().nbytes()), str_data_ptr, memory_size); | |||
| @@ -61,9 +61,9 @@ class Tensor(Tensor_): | |||
| if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | |||
| input_data = np.ascontiguousarray(input_data) | |||
| if dtype is None: | |||
| super(Tensor, self).__init__(input_data) | |||
| Tensor_.__init__(self, input_data) | |||
| else: | |||
| super(Tensor, self).__init__(input_data, dtype) | |||
| Tensor_.__init__(self, input_data, dtype) | |||
| self._virtual_flag = False | |||
| self._init_flag = False | |||
| @@ -55,6 +55,7 @@ def rmsprop_numpy(variable, gradients, mean_square, moment, | |||
| mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients | |||
| moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients | |||
| variable = variable - moment | |||
| return variable, gradients, mean_square, moment | |||
| def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment, | |||
| @@ -64,7 +65,7 @@ def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment | |||
| moment = momentum * moment + learning_rate / np.sqrt( | |||
| mean_square - mean_gradients * mean_gradients + epsilon) * gradients | |||
| variable = variable - moment | |||
| return variable, gradients, mean_gradients, mean_square, moment | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -85,12 +86,14 @@ def test_rmsprop(): | |||
| moment_ms = Tensor(moment_np) | |||
| if centered: | |||
| variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \ | |||
| rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon) | |||
| _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms) | |||
| else: | |||
| variable_np, gradients_np, mean_square_np, moment_np = \ | |||
| rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetRMSProp(learning_rate, decay, momentum, epsilon) | |||
| @@ -136,11 +139,13 @@ def test_rmspropcenter(): | |||
| moment_ms = Tensor(moment_np) | |||
| if centered: | |||
| variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np = \ | |||
| rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetCenteredRMSProp(learning_rate, decay, momentum, epsilon) | |||
| _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms) | |||
| else: | |||
| variable_np, gradients_np, mean_square_np, moment_np = \ | |||
| rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, | |||
| learning_rate, decay, momentum, epsilon) | |||
| net = NetRMSProp(learning_rate, decay, momentum, epsilon) | |||
| @@ -22,6 +22,9 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/tensor_py.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| @@ -90,9 +93,7 @@ TEST_F(TestMetaTensor, EqualTest) { | |||
| class TestTensor : public UT::Common { | |||
| public: | |||
| TestTensor() {} | |||
| virtual void SetUp() { | |||
| UT::InitPythonPath(); | |||
| } | |||
| virtual void SetUp() { UT::InitPythonPath(); } | |||
| }; | |||
| py::array_t<float, py::array::c_style> BuildInputTensor() { | |||
| @@ -124,7 +125,7 @@ TEST_F(TestTensor, PyArrayScalarTest) { | |||
| TEST_F(TestTensor, InitScalarTest) { | |||
| std::vector<int> dimensions; | |||
| Tensor tensor(TypeId::kNumberTypeInt64, dimensions); | |||
| uint8_t *data_buf = reinterpret_cast<uint8_t *>(tensor.data_c(true)); | |||
| uint8_t *data_buf = reinterpret_cast<uint8_t *>(tensor.data_c()); | |||
| int64_t num = 1; | |||
| errno_t ret = memcpy_s(data_buf, sizeof(int64_t), &num, sizeof(int64_t)); | |||
| @@ -172,9 +173,9 @@ TEST_F(TestTensor, InitTensorPtrTest) { | |||
| } | |||
| TEST_F(TestTensor, InitByTupleTest) { | |||
| py::tuple dimensions = py::make_tuple(2, 3, 4); | |||
| const std::vector<int> shape = {2, 3, 4}; | |||
| TypePtr data_type = kFloat32; | |||
| Tensor tuple_tensor = Tensor(data_type, dimensions); | |||
| Tensor tuple_tensor(data_type->type_id(), shape); | |||
| ASSERT_EQ(2, tuple_tensor.DimensionSize(0)); | |||
| ASSERT_EQ(3, tuple_tensor.DimensionSize(1)); | |||
| ASSERT_EQ(4, tuple_tensor.DimensionSize(2)); | |||
| @@ -184,8 +185,8 @@ TEST_F(TestTensor, InitByTupleTest) { | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat32, tuple_tensor.data_type()); | |||
| py::tuple tuple = py::make_tuple(1.0, 2.0, 3, 4, 5, 6); | |||
| TensorPtr tensor = std::make_shared<Tensor>(tuple, kFloat64); | |||
| py::array array = tensor->data(); | |||
| TensorPtr tensor = TensorPy::MakeTensor(py::array(tuple), kFloat64); | |||
| py::array array = TensorPy::AsNumpy(*tensor); | |||
| std::cout << "Dim: " << array.ndim() << std::endl; | |||
| ASSERT_EQ(1, array.ndim()); | |||
| @@ -203,24 +204,24 @@ TEST_F(TestTensor, InitByTupleTest) { | |||
| TEST_F(TestTensor, EqualTest) { | |||
| py::tuple tuple = py::make_tuple(1, 2, 3, 4, 5, 6); | |||
| TensorPtr tensor_int8 = std::make_shared<Tensor>(tuple, kInt8); | |||
| TensorPtr tensor_int8 = TensorPy::MakeTensor(py::array(tuple), kInt8); | |||
| ASSERT_TRUE(*tensor_int8 == *tensor_int8); | |||
| ASSERT_EQ(TypeId::kNumberTypeInt8, tensor_int8->data_type_c()); | |||
| TensorPtr tensor_int16 = std::make_shared<Tensor>(tuple, kInt16); | |||
| TensorPtr tensor_int16 = TensorPy::MakeTensor(py::array(tuple), kInt16); | |||
| ASSERT_EQ(TypeId::kNumberTypeInt16, tensor_int16->data_type_c()); | |||
| TensorPtr tensor_int32 = std::make_shared<Tensor>(tuple, kInt32); | |||
| TensorPtr tensor_int32 = TensorPy::MakeTensor(py::array(tuple), kInt32); | |||
| ASSERT_EQ(TypeId::kNumberTypeInt32, tensor_int32->data_type_c()); | |||
| TensorPtr tensor_float16 = std::make_shared<Tensor>(tuple, kFloat16); | |||
| TensorPtr tensor_float16 = TensorPy::MakeTensor(py::array(tuple), kFloat16); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat16, tensor_float16->data_type_c()); | |||
| TensorPtr tensor_float32 = std::make_shared<Tensor>(tuple, kFloat32); | |||
| TensorPtr tensor_float32 = TensorPy::MakeTensor(py::array(tuple), kFloat32); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat32, tensor_float32->data_type_c()); | |||
| TensorPtr tensor_float64 = std::make_shared<Tensor>(tuple, kFloat64); | |||
| TensorPtr tensor_float64 = TensorPy::MakeTensor(py::array(tuple), kFloat64); | |||
| ASSERT_EQ(TypeId::kNumberTypeFloat64, tensor_float64->data_type_c()); | |||
| } | |||
| @@ -247,7 +248,7 @@ TEST_F(TestTensor, PyArrayTest) { | |||
| TEST_F(TestTensor, InitByFloatArrayDataCTest) { | |||
| // Init tensor data by py::array_t<float> | |||
| auto tensor = std::make_shared<Tensor>(BuildInputTensor()); | |||
| auto tensor = TensorPy::MakeTensor(BuildInputTensor()); | |||
| // Print some information of the tensor | |||
| std::cout << "Datatype: " << tensor->data_type() << std::endl; | |||
| @@ -269,7 +270,7 @@ TEST_F(TestTensor, InitByFloatArrayDataCTest) { | |||
| TEST_F(TestTensor, InitByFloatArrayDataTest) { | |||
| // Init tensor data by py::array_t<float> | |||
| TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor()); | |||
| TensorPtr tensor = TensorPy::MakeTensor(BuildInputTensor()); | |||
| // Print some information of the tensor | |||
| std::cout << "Datatype: " << tensor->data_type() << std::endl; | |||
| @@ -291,7 +292,7 @@ TEST_F(TestTensor, InitByFloatArrayDataTest) { | |||
| // Print each elements | |||
| std::cout << "Elements: " << std::endl; | |||
| py::array_t<float> data = (py::array_t<float>)tensor->data(); | |||
| py::array_t<float> data = py::cast<py::array_t<float>>(TensorPy::AsNumpy(*tensor)); | |||
| auto array = data.unchecked<2>(); | |||
| for (int i = 0; i < array.shape(0); i++) { | |||
| for (int j = 0; j < array.shape(1); j++) { | |||
| @@ -319,17 +320,17 @@ TEST_F(TestTensor, TensorDataTest) { | |||
| float ge_tensor_data[] = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}; | |||
| // Create a Tensor with wanted data type and shape | |||
| Tensor tensor = Tensor(TypeId::kNumberTypeFloat32, std::vector<int>({2, 3})); | |||
| Tensor tensor(TypeId::kNumberTypeFloat32, std::vector<int>({2, 3})); | |||
| // Get the writable data pointer from the tensor | |||
| float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c(true)); | |||
| float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c()); | |||
| // Copy data from buffer to tensor's data | |||
| errno_t ret = memcpy_s(me_tensor_data, tensor.data().nbytes(), ge_tensor_data, sizeof(ge_tensor_data)); | |||
| ASSERT_EQ(0, ret); | |||
| // Testify if the data has been copied to the tensor data | |||
| py::array_t<float> data = (py::array_t<float>)tensor.data(); | |||
| py::array_t<float> data = py::cast<py::array_t<float>>(TensorPy::AsNumpy(tensor)); | |||
| auto array = data.mutable_unchecked(); | |||
| for (int i = 0; i < array.shape(0); i++) { | |||
| for (int j = 0; j < array.shape(1); j++) { | |||
| @@ -340,5 +341,17 @@ TEST_F(TestTensor, TensorDataTest) { | |||
| } | |||
| } | |||
| TEST_F(TestTensor, TensorPyCast) { | |||
| std::vector<int> shape{2, 3, 4, 5}; | |||
| py::tuple py_tuple = py::make_tuple(std::make_shared<Tensor>(kNumberTypeFloat32, shape)); | |||
| auto shape1 = py::cast<Tensor &>(py_tuple[0]).shape(); | |||
| const py::tuple &t = py_tuple; | |||
| auto shape2 = py::cast<const Tensor &>(t[0]).shape(); | |||
| auto shape3 = py::cast<Tensor &>(t[0]).shape(); | |||
| ASSERT_EQ(shape, shape1); | |||
| ASSERT_EQ(shape, shape2); | |||
| ASSERT_EQ(shape, shape3); | |||
| } | |||
| } // namespace tensor | |||
| } // namespace mindspore | |||
| @@ -60,15 +60,9 @@ CNodePtr Make_Node(Shape x, Shape y, Shape out, int condition = 0) { | |||
| BaseShapePtr shape1 = std::make_shared<abstract::Shape>(x); | |||
| BaseShapePtr shape2 = std::make_shared<abstract::Shape>(y); | |||
| BaseShapePtr shape3 = std::make_shared<abstract::Shape>(out); | |||
| std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(); | |||
| inputs_x->set_data_type(kNumberTypeInt32); | |||
| inputs_x->set_shape(x); | |||
| std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(); | |||
| inputs_y->set_data_type(kNumberTypeInt32); | |||
| inputs_y->set_shape(y); | |||
| std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(); | |||
| inputs_out->set_data_type(kNumberTypeInt32); | |||
| inputs_out->set_shape(out); | |||
| std::shared_ptr<tensor::Tensor> inputs_x = std::make_shared<tensor::Tensor>(kNumberTypeInt32, x); | |||
| std::shared_ptr<tensor::Tensor> inputs_y = std::make_shared<tensor::Tensor>(kNumberTypeInt32, y); | |||
| std::shared_ptr<tensor::Tensor> inputs_out = std::make_shared<tensor::Tensor>(kNumberTypeInt32, out); | |||
| AbstractBasePtr abstract1 = abstract::FromValue(inputs_x, true); | |||
| AbstractBasePtr abstract2 = abstract::FromValue(inputs_y, true); | |||
| AbstractBasePtr abstract3 = abstract::FromValue(inputs_out, true); | |||
| @@ -127,21 +121,11 @@ FuncGraphManagerPtr Make_Manager(int condition = 0) { | |||
| ParameterPtr param1 = func_graph->add_parameter(); | |||
| ParameterPtr param2 = func_graph->add_parameter(); | |||
| ParameterPtr param3 = func_graph->add_parameter(); | |||
| std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(); | |||
| inputs_x_dim->set_data_type(kNumberTypeInt32); | |||
| inputs_x_dim->set_shape(inputs_x); | |||
| std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(); | |||
| inputs_y_dim->set_data_type(kNumberTypeInt32); | |||
| inputs_y_dim->set_shape(inputs_y); | |||
| std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(); | |||
| inputs_z_dim->set_data_type(kNumberTypeInt32); | |||
| inputs_z_dim->set_shape(inputs_z); | |||
| std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(); | |||
| inputs_out1_dim->set_data_type(kNumberTypeInt32); | |||
| inputs_out1_dim->set_shape(outputs_1); | |||
| std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(); | |||
| inputs_out2_dim->set_data_type(kNumberTypeInt32); | |||
| inputs_out2_dim->set_shape(outputs_2); | |||
| std::shared_ptr<tensor::Tensor> inputs_x_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_x); | |||
| std::shared_ptr<tensor::Tensor> inputs_y_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_y); | |||
| std::shared_ptr<tensor::Tensor> inputs_z_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, inputs_z); | |||
| std::shared_ptr<tensor::Tensor> inputs_out1_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_1); | |||
| std::shared_ptr<tensor::Tensor> inputs_out2_dim = std::make_shared<tensor::Tensor>(kNumberTypeInt32, outputs_2); | |||
| AbstractBasePtr abstract_x = abstract::FromValue(inputs_x_dim, true); | |||
| AbstractBasePtr abstract_y = abstract::FromValue(inputs_y_dim, true); | |||
| AbstractBasePtr abstract_z = abstract::FromValue(inputs_z_dim, true); | |||
| @@ -113,12 +113,8 @@ TEST_F(TestData, test_build_shape) { | |||
| std::vector<int> weight1_dims = {2, 20, 5, 5}; | |||
| std::vector<int> weight2_dims = {2, 2, 5, 5}; | |||
| tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>(); | |||
| weight1->set_data_type(kNumberTypeInt32); | |||
| weight1->set_shape(weight1_dims); | |||
| tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>(); | |||
| weight2->set_data_type(kNumberTypeInt32); | |||
| weight2->set_shape(weight2_dims); | |||
| tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>(kNumberTypeInt32, weight1_dims); | |||
| tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>(kNumberTypeInt32, weight2_dims); | |||
| AbstractBasePtr abstract_weight1 = FromValue(weight1, true); | |||
| AbstractBasePtr abstract_weight2 = FromValue(weight2, true); | |||
| @@ -104,7 +104,7 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) { | |||
| EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1)); | |||
| auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>(); | |||
| ASSERT_TRUE(tensor != nullptr); | |||
| auto data = tensor->data_c(false); | |||
| auto data = tensor->data_c(); | |||
| EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2})); | |||
| } | |||
| } // namespace opt | |||
| @@ -706,7 +706,7 @@ TEST_F(TestConvert, TestConvertTensor) { | |||
| auto type_id = kNumberTypeFloat32; | |||
| MeTensor me_tensor(type_id, dims); | |||
| // Get the writable data pointer of the tensor and cast it to its data type | |||
| uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c(true)); | |||
| uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c()); | |||
| // Copy or use the writable data pointer of the ME tensor | |||
| memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float)); | |||
| auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/tensor_py.h" | |||
| #include "transform/transform_base_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| @@ -35,6 +36,8 @@ | |||
| #define private public | |||
| #include "transform/graph_runner.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace transform { | |||
| class TestGraphRunner : public UT::Common { | |||
| @@ -70,7 +73,7 @@ std::shared_ptr<DfGraphConvertor> MakeGeGraph() { | |||
| return std::make_shared<DfGraphConvertor>(anf_graph); | |||
| } | |||
| namespace { | |||
| std::shared_ptr<std::vector<MeTensorPtr>> DoExecGraph(const std::vector<MeTensorPtr>& inputs) { | |||
| std::shared_ptr<std::vector<MeTensorPtr>> DoExecGraph(const std::vector<MeTensorPtr> &inputs) { | |||
| std::vector<GeTensorPtr> ge_tensor_ptrs = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); | |||
| std::vector<GeTensorPtr> ge_outputs; | |||
| @@ -109,7 +112,7 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) { | |||
| MeTensor tensor = MeTensor(TypeId::kNumberTypeFloat32, std::vector<int>({1, 2, 3})); | |||
| // Get the writable data pointer from the tensor | |||
| float* me_tensor_data = reinterpret_cast<float*>(tensor.data_c(true)); | |||
| float *me_tensor_data = reinterpret_cast<float *>(tensor.data_c()); | |||
| // Copy data from buffer to tensor's data | |||
| memcpy_s(me_tensor_data, static_cast<size_t>(tensor.data().nbytes()), ge_tensor_data, sizeof(ge_tensor_data)); | |||
| @@ -119,11 +122,11 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) { | |||
| py::tuple py_tuple = | |||
| py::make_tuple(py::make_tuple(py::make_tuple(1.1f, 2.2f, 3.3f), py::make_tuple(4.4f, 5.5f, 6.6f))); | |||
| py::array my_arry = py::array(py_tuple).attr("astype").cast<py::function>()("float32").cast<py::array>(); | |||
| MeTensor tensor_tuple = MeTensor(my_arry, kFloat32); | |||
| PrintMeTensor(&tensor_tuple); | |||
| auto tensor_tuple = TensorPy::MakeTensor(my_arry, kFloat32); | |||
| PrintMeTensor(tensor_tuple.get()); | |||
| py::array tensor_array = tensor.data(); | |||
| py::array tensor_tuple_array = tensor_tuple.data(); | |||
| py::array tensor_array = TensorPy::AsNumpy(tensor); | |||
| py::array tensor_tuple_array = TensorPy::AsNumpy(*tensor_tuple); | |||
| assert(memcmp(ge_tensor_data, tensor_array.data(), sizeof(ge_tensor_data)) == 0); | |||
| assert(memcmp(ge_tensor_data, tensor_tuple_array.data(), sizeof(ge_tensor_data)) == 0); | |||
| } | |||
| @@ -131,7 +134,7 @@ TEST_F(TestGraphRunner, TestGeTensorConstructor) { | |||
| #if (!defined ENABLE_GE) | |||
| TEST_F(TestGraphRunner, TestRunGraphException) { | |||
| DfGraphManager& graph_manager = DfGraphManager::GetInstance(); | |||
| DfGraphManager &graph_manager = DfGraphManager::GetInstance(); | |||
| graph_manager.ClearGraph(); | |||
| std::map<string, MeTensorPtr> dict; | |||
| @@ -167,7 +170,7 @@ TEST_F(TestGraphRunner, TestRunGraphException) { | |||
| } | |||
| TEST_F(TestGraphRunner, TestRunGraph) { | |||
| DfGraphManager& graph_manager = DfGraphManager::GetInstance(); | |||
| DfGraphManager &graph_manager = DfGraphManager::GetInstance(); | |||
| graph_manager.ClearGraph(); | |||
| std::shared_ptr<DfGraphConvertor> convertor = MakeGeGraph(); | |||
| @@ -183,7 +186,7 @@ TEST_F(TestGraphRunner, TestRunGraph) { | |||
| py::make_tuple(py::make_tuple(py::make_tuple(1.0, 2.0, 3.0, 4.0), py::make_tuple(4.0, 5.0, 6.0, 7.0))), | |||
| py::make_tuple(py::make_tuple(py::make_tuple(1.0, 2.0, 3.0, 4.0), py::make_tuple(4.0, 5.0, 6.0, 7.0)))); | |||
| py::array array = py::array(tuple); | |||
| MeTensorPtr me_tensor_ptr = std::make_shared<MeTensor>(array, type_id); | |||
| MeTensorPtr me_tensor_ptr = TensorPy::MakeTensor(array, type_id); | |||
| MS_LOG(INFO) << "inputs me tensor data is: "; | |||
| PrintMeTensor(&(*me_tensor_ptr)); | |||
| @@ -204,7 +207,7 @@ TEST_F(TestGraphRunner, TestRunGraph) { | |||
| } | |||
| TEST_F(TestGraphRunner, TestAPI) { | |||
| DfGraphManager& graph_manager = DfGraphManager::GetInstance(); | |||
| DfGraphManager &graph_manager = DfGraphManager::GetInstance(); | |||
| graph_manager.ClearGraph(); | |||
| std::shared_ptr<DfGraphConvertor> convertor = MakeGeGraph(); | |||
| @@ -16,6 +16,9 @@ | |||
| #include <iostream> | |||
| #include "common/common_test.h" | |||
| #include "transform/transform_base_test.h" | |||
| #include "ir/tensor_py.h" | |||
| using mindspore::tensor::TensorPy; | |||
| namespace mindspore { | |||
| namespace transform { | |||
| @@ -55,10 +58,10 @@ void PrintMeTensor(MeTensor* tensor) { | |||
| } | |||
| std::cout << "the py::str() data is: " << std::endl; | |||
| py::array tensor_data = (*tensor).data(); | |||
| py::array tensor_data = TensorPy::AsNumpy(*tensor); | |||
| std::cout << std::string(py::str(tensor_data)) << std::endl; | |||
| std::cout << "tensor dtype is: " << std::string(tensor->data().dtype().str()) << std::endl; | |||
| std::cout << "tensor dtype is: " << std::string(tensor_data.dtype().str()) << std::endl; | |||
| } | |||
| FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) { | |||
| @@ -73,7 +76,7 @@ FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, unsigned int nparam) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim)); | |||
| for (unsigned int i = 0; i < nparam; i++) { | |||
| if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" || | |||
| if ((prim->name() == "ScalarSummary" || prim->name() == "TensorSummary" || | |||
| prim->name() == "ImageSummary" || prim->name() == "HistogramSummary") && | |||
| i == 0) { | |||
| auto input = NewValueNode("testSummary"); | |||