Browse Source

!167 Serving, opt predict callback

From: @xu-yfei
Reviewed-by: @zhoufeng54,@zhangyinxia
Signed-off-by: @zhangyinxia
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
4fa4109031
49 changed files with 1119 additions and 342 deletions
  1. +1
    -1
      example/matmul_distributed/export_model/distributed_inference.py
  2. +1
    -1
      example/matmul_distributed/export_model/net.py
  3. +3
    -3
      mindspore_serving/ccsrc/common/grpc_client.h
  4. +28
    -2
      mindspore_serving/ccsrc/common/proto_tensor.cc
  5. +6
    -2
      mindspore_serving/ccsrc/common/proto_tensor.h
  6. +39
    -17
      mindspore_serving/ccsrc/master/dispacther.cc
  7. +5
    -2
      mindspore_serving/ccsrc/master/dispacther.h
  8. +3
    -47
      mindspore_serving/ccsrc/master/grpc/grpc_process.cc
  9. +1
    -1
      mindspore_serving/ccsrc/master/grpc/grpc_process.h
  10. +2
    -5
      mindspore_serving/ccsrc/master/grpc/grpc_server.h
  11. +1
    -1
      mindspore_serving/ccsrc/master/notify_worker/base_notify.h
  12. +10
    -5
      mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc
  13. +4
    -4
      mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h
  14. +2
    -2
      mindspore_serving/ccsrc/master/notify_worker/local_notify.cc
  15. +1
    -1
      mindspore_serving/ccsrc/master/notify_worker/local_notify.h
  16. +28
    -53
      mindspore_serving/ccsrc/master/restful/http_process.cc
  17. +0
    -1
      mindspore_serving/ccsrc/master/restful/http_process.h
  18. +4
    -4
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
  19. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc
  20. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h
  21. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc
  22. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h
  23. +8
    -14
      mindspore_serving/ccsrc/worker/grpc/worker_process.cc
  24. +6
    -6
      mindspore_serving/ccsrc/worker/grpc/worker_process.h
  25. +2
    -5
      mindspore_serving/ccsrc/worker/grpc/worker_server.h
  26. +12
    -12
      mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc
  27. +4
    -11
      mindspore_serving/ccsrc/worker/worker.cc
  28. +1
    -1
      mindspore_serving/ccsrc/worker/worker.h
  29. +2
    -5
      tests/ut/cpp/common/test_servable_common.h
  30. +368
    -1
      tests/ut/python/tests/test_mater_worker_client.py
  31. +85
    -34
      tests/ut/stub/cxx_api/context.cc
  32. +9
    -9
      tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.cc
  33. +7
    -7
      tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h
  34. +25
    -3
      tests/ut/stub/cxx_api/model/model.cc
  35. +4
    -1
      tests/ut/stub/cxx_api/model/model_impl.h
  36. +5
    -0
      tests/ut/stub/cxx_api/model/ms/ms_model.cc
  37. +8
    -4
      tests/ut/stub/cxx_api/serialization.cc
  38. +207
    -0
      tests/ut/stub/cxx_api/status.cc
  39. +11
    -9
      tests/ut/stub/cxx_api/types.cc
  40. +86
    -15
      tests/ut/stub/include/api/context.h
  41. +1
    -4
      tests/ut/stub/include/api/data_type.h
  42. +26
    -0
      tests/ut/stub/include/api/dual_abi_helper.h
  43. +0
    -1
      tests/ut/stub/include/api/graph.h
  44. +8
    -1
      tests/ut/stub/include/api/model.h
  45. +9
    -1
      tests/ut/stub/include/api/serialization.h
  46. +45
    -19
      tests/ut/stub/include/api/status.h
  47. +35
    -22
      tests/ut/stub/include/api/types.h
  48. +1
    -0
      tests/ut/stub/include/utils/utils.h
  49. +1
    -1
      third_party/mindspore

+ 1
- 1
example/matmul_distributed/export_model/distributed_inference.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


+ 1
- 1
example/matmul_distributed/export_model/net.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.


+ 3
- 3
mindspore_serving/ccsrc/common/grpc_client.h View File

@@ -39,7 +39,7 @@ namespace serving {
using PredictOnFinish = std::function<void()>;
using DispatchCallback = std::function<void(Status status)>;
using AsyncPredictCallback = std::function<void(Status status)>;
template <typename Request, typename Reply, typename MSStub>
class MSServiceClient {
@@ -80,7 +80,7 @@ class MSServiceClient {
}
}
void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) {
void PredictAsync(const Request &request, Reply *reply, MSStub *stub, AsyncPredictCallback callback) {
AsyncClientCall *call = new AsyncClientCall;
call->reply = reply;
call->callback = std::move(callback);
@@ -95,7 +95,7 @@ class MSServiceClient {
grpc::ClientContext context;
grpc::Status status;
Reply *reply;
DispatchCallback callback;
AsyncPredictCallback callback;
std::shared_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader;
};


+ 28
- 2
mindspore_serving/ccsrc/common/proto_tensor.cc View File

@@ -256,8 +256,17 @@ Status GrpcTensorHelper::CreateInstanceFromRequest(const proto::PredictRequest &
return SUCCESS;
}

Status GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request,
const vector<InstancePtr> &instances, proto::PredictReply *reply) {
void GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request,
const vector<InstancePtr> &instances, proto::PredictReply *reply) {
auto status = CreateReplyFromInstancesInner(request, instances, reply);
if (status != SUCCESS) {
CreateReplyFromErrorMsg(status, reply);
}
}

Status GrpcTensorHelper::CreateReplyFromInstancesInner(const proto::PredictRequest &request,
const std::vector<InstancePtr> &instances,
proto::PredictReply *reply) {
MSI_EXCEPTION_IF_NULL(reply);
*reply->mutable_servable_spec() = request.servable_spec();
if (instances.empty()) {
@@ -422,6 +431,23 @@ Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) {
return SUCCESS;
}

void GrpcTensorHelper::CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply) {
MSI_EXCEPTION_IF_NULL(reply);
if (error_msg == SUCCESS) {
return;
}
reply->clear_error_msg();
reply->clear_instances();
auto proto_error_msg = reply->add_error_msg();
proto_error_msg->set_error_code(FAILED);
std::string error_msg_str = error_msg.StatusMessage();
if (error_msg_str.empty()) {
proto_error_msg->set_error_msg("Predict failed");
} else {
proto_error_msg->set_error_msg(error_msg_str);
}
}

serving::LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type) {
const std::map<proto::DataType, std::string> type_name_map{
{proto::MS_UNKNOWN, "proto::MS_UNKNOWN"}, {proto::MS_BOOL, "proto::kMSI_Bool"},


+ 6
- 2
mindspore_serving/ccsrc/common/proto_tensor.h View File

@@ -67,8 +67,9 @@ class MS_API GrpcTensorHelper {
static void GetWorkerSpec(const proto::RemoveWorkerRequest &request, WorkerSpec *worker_spec);
static Status CreateInstanceFromRequest(const proto::PredictRequest &request, RequestSpec *request_spec,
std::vector<InstanceData> *results);
static Status CreateReplyFromInstances(const proto::PredictRequest &request,
const std::vector<InstancePtr> &instances, proto::PredictReply *reply);
static void CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector<InstancePtr> &instances,
proto::PredictReply *reply);
static void CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply);
static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs);
static void CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs,
proto::AgentRegisterRequest *request);
@@ -78,6 +79,9 @@ class MS_API GrpcTensorHelper {
const std::vector<std::string> &input_names,
std::vector<InstanceData> *results);
static Status CheckRequestTensor(const proto::Tensor &tensor);

static Status CreateReplyFromInstancesInner(const proto::PredictRequest &request,
const std::vector<InstancePtr> &instances, proto::PredictReply *reply);
};

extern MS_API LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type);


+ 39
- 17
mindspore_serving/ccsrc/master/dispacther.cc View File

@@ -55,25 +55,47 @@ DispatcherWorkerContext Dispatcher::GetWorkSession(const RequestSpec &request_sp
return context;
}

Status Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) {
void Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) {
MSI_EXCEPTION_IF_NULL(reply);
auto promise = std::make_shared<std::pair<std::promise<void>, Status>>(std::make_pair(std::promise<void>(), FAILED));
auto future = promise->first.get_future();
DispatchCallback callback = [promise](Status status) {
promise->second = status;
promise->first.set_value();
};
auto status = DispatchAsync(request, reply, callback);
auto promise = std::make_shared<std::promise<void>>();
auto future = promise->get_future();
PredictOnFinish on_finish = [promise]() { promise->set_value(); };
DispatchAsync(request, reply, on_finish);
future.get(); // wait callback finish
}

void Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
PredictOnFinish on_finish) {
MSI_EXCEPTION_IF_NULL(reply);
Status status;
(*reply->mutable_servable_spec()) = request.servable_spec();
try {
MSI_TIME_STAMP_START(Predict)
status = DispatchAsyncInner(request, reply, on_finish);
MSI_TIME_STAMP_END(Predict)
} catch (const std::bad_alloc &ex) {
MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
std::cout << "Serving Error: malloc memory failed" << std::endl;
} catch (const std::runtime_error &ex) {
MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what();
std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl;
} catch (const std::exception &ex) {
MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what();
std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl;
} catch (...) {
MSI_LOG(ERROR) << "Serving Error: exception occurred";
std::cout << "Serving Error: exception occurred";
}
MSI_LOG(INFO) << "Finish call service Eval";

if (status != SUCCESS) {
MSI_LOG_ERROR << "DispatchAsync failed";
return status;
GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
on_finish();
}
future.get(); // wait callback finish
return promise->second;
}

Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) {
Status Dispatcher::DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply,
PredictOnFinish on_finish) {
MSI_EXCEPTION_IF_NULL(reply);
std::shared_lock<std::shared_mutex> lock(servable_shared_lock_);
RequestSpec request_spec;
@@ -88,7 +110,7 @@ Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::Pr
if (!find_method) {
return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available";
}
return worker.notify_worker_->DispatchAsync(request, reply, std::move(callback));
return worker.notify_worker_->DispatchAsync(request, reply, std::move(on_finish));
}

Status Dispatcher::RegisterServableCommon(const std::vector<WorkerSpec> &worker_specs, CreateNotifyWorkerFunc func) {
@@ -216,7 +238,7 @@ Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto
std::vector<WorkerSpec> worker_specs;
GrpcTensorHelper::GetWorkerSpec(request, &worker_specs);
auto create_notify_worker = [](const WorkerSpec &worker_spec) {
std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotfiyWorker>(worker_spec.worker_address);
std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address);
return notify_worker;
};
return RegisterServableCommon(worker_specs, create_notify_worker);
@@ -232,7 +254,7 @@ Status Dispatcher::AddServable(const proto::AddWorkerRequest &request, proto::Ad
GrpcTensorHelper::GetWorkerSpec(request, &worker_spec);

auto create_notify_worker = [](const WorkerSpec &worker_spec) {
std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotfiyWorker>(worker_spec.worker_address);
std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address);
return notify_worker;
};
return AddServableCommon(worker_spec, create_notify_worker);


+ 5
- 2
mindspore_serving/ccsrc/master/dispacther.h View File

@@ -40,8 +40,8 @@ class MS_API Dispatcher {
public:
Dispatcher();
~Dispatcher();
Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply);
Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback);
void Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply);
void DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish);

Status RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply *reply);
Status UnregisterServable(const proto::ExitRequest &request, proto::ExitReply *reply);
@@ -71,6 +71,9 @@ class MS_API Dispatcher {
Status UnregisterServableCommon(const std::string &worker_address);
Status AddServableCommon(const WorkerSpec &worker_spec, CreateNotifyWorkerFunc func);
Status RemoveServableCommon(const WorkerSpec &worker_spec);

Status DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply,
PredictOnFinish on_finish);
};

} // namespace mindspore::serving


+ 3
- 47
mindspore_serving/ccsrc/master/grpc/grpc_process.cc View File

@@ -36,53 +36,9 @@ std::string GetProtorWorkerSpecRepr(const proto::WorkerSpec &worker_spec) {
}
} // namespace

Status MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply,
DispatchCallback callback) {
MSI_EXCEPTION_IF_NULL(request);
MSI_EXCEPTION_IF_NULL(reply);
Status status(FAILED);
auto on_status = [request, reply](Status status) {
if (status != SUCCESS) {
(*reply->mutable_servable_spec()) = request->servable_spec();
reply->clear_error_msg();
auto proto_error_msg = reply->add_error_msg();
proto_error_msg->set_error_code(FAILED);
std::string error_msg = status.StatusMessage();
if (error_msg.empty()) {
proto_error_msg->set_error_msg("Predict failed");
} else {
proto_error_msg->set_error_msg(error_msg);
}
}
};
DispatchCallback callback_with_status_handle = [callback, on_status](Status status) {
on_status(status);
callback(status);
};
try {
MSI_TIME_STAMP_START(Predict)
status = dispatcher_->DispatchAsync(*request, reply, callback_with_status_handle);
MSI_TIME_STAMP_END(Predict)
} catch (const std::bad_alloc &ex) {
MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
std::cout << "Serving Error: malloc memory failed" << std::endl;
} catch (const std::runtime_error &ex) {
MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what();
std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl;
} catch (const std::exception &ex) {
MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what();
std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl;
} catch (...) {
MSI_LOG(ERROR) << "Serving Error: exception occurred";
std::cout << "Serving Error: exception occurred";
}
MSI_LOG(INFO) << "Finish call service Eval";

if (status != SUCCESS) {
on_status(status);
return status;
}
return SUCCESS;
void MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply,
PredictOnFinish on_finish) {
dispatcher_->DispatchAsync(*request, reply, on_finish);
}

grpc::Status MSMasterImpl::Register(grpc::ServerContext *context, const proto::RegisterRequest *request,


+ 1
- 1
mindspore_serving/ccsrc/master/grpc/grpc_process.h View File

@@ -41,7 +41,7 @@ class MSServiceImpl {
explicit MSServiceImpl(std::shared_ptr<Dispatcher> dispatcher) : dispatcher_(dispatcher) {}
~MSServiceImpl() = default;

Status PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, DispatchCallback callback);
void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, PredictOnFinish on_finish);

private:
std::shared_ptr<Dispatcher> dispatcher_;


+ 2
- 5
mindspore_serving/ccsrc/master/grpc/grpc_server.h View File

@@ -94,11 +94,8 @@ class MasterPredictContext : public MasterServiceContext {
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); };
Status status = service_impl_->PredictAsync(&request_, &response_, callback);
if (!status.IsSuccess()) {
responder_.Finish(response_, grpc::Status::OK, this);
}
PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); };
service_impl_->PredictAsync(&request_, &response_, on_finish);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }


+ 1
- 1
mindspore_serving/ccsrc/master/notify_worker/base_notify.h View File

@@ -33,7 +33,7 @@ class MS_API BaseNotifyWorker {
virtual ~BaseNotifyWorker() = default;
virtual Status Exit() = 0;
virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) = 0;
PredictOnFinish on_finish) = 0;
};

} // namespace serving


+ 10
- 5
mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc View File

@@ -20,19 +20,20 @@
#include <thread>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {

GrpcNotfiyWorker::GrpcNotfiyWorker(const std::string &worker_address) {
GrpcNotifyWorker::GrpcNotifyWorker(const std::string &worker_address) {
worker_address_ = worker_address;
std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(worker_address);
stub_ = proto::MSWorker::NewStub(channel);
}

GrpcNotfiyWorker::~GrpcNotfiyWorker() = default;
GrpcNotifyWorker::~GrpcNotifyWorker() = default;

Status GrpcNotfiyWorker::Exit() {
Status GrpcNotifyWorker::Exit() {
if (stub_) {
proto::ExitRequest request;
request.set_address(worker_address_);
@@ -47,8 +48,8 @@ Status GrpcNotfiyWorker::Exit() {
return SUCCESS;
}

Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) {
Status GrpcNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
PredictOnFinish on_finish) {
if (!stub_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Predict failed, worker gRPC has not been inited or has already exited, worker address "
@@ -58,6 +59,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro
client_ = std::make_unique<MSPredictClient>();
client_->Start();
}
AsyncPredictCallback callback = [reply, on_finish](Status status) {
GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
on_finish();
};
client_->PredictAsync(request, reply, stub_.get(), callback);
return SUCCESS;
}


+ 4
- 4
mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h View File

@@ -27,15 +27,15 @@
namespace mindspore {
namespace serving {

class MS_API GrpcNotfiyWorker : public BaseNotifyWorker {
class MS_API GrpcNotifyWorker : public BaseNotifyWorker {
public:
explicit GrpcNotfiyWorker(const std::string &worker_address);
~GrpcNotfiyWorker() override;
explicit GrpcNotifyWorker(const std::string &worker_address);
~GrpcNotifyWorker() override;

Status Exit() override;

Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) override;
PredictOnFinish on_finish) override;

private:
std::string worker_address_;


+ 2
- 2
mindspore_serving/ccsrc/master/notify_worker/local_notify.cc View File

@@ -26,8 +26,8 @@ Status LocalNotifyWorker::Exit() {
}

Status LocalNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) {
return Worker::GetInstance().RunAsync(request, reply, callback);
PredictOnFinish on_finish) {
return Worker::GetInstance().RunAsync(request, reply, on_finish);
}

} // namespace serving


+ 1
- 1
mindspore_serving/ccsrc/master/notify_worker/local_notify.h View File

@@ -30,7 +30,7 @@ class MS_API LocalNotifyWorker : public BaseNotifyWorker {
Status Exit() override;

Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
DispatchCallback callback) override;
PredictOnFinish on_finish) override;
};

} // namespace serving


+ 28
- 53
mindspore_serving/ccsrc/master/restful/http_process.cc View File

@@ -693,14 +693,8 @@ Status RestfulService::RunRestful(const std::shared_ptr<RestfulRequest> &restful
}

MSI_TIME_STAMP_START(Predict)
status = dispatcher_->Dispatch(request, &reply);
dispatcher_->Dispatch(request, &reply);
MSI_TIME_STAMP_END(Predict)
if (status != SUCCESS) {
std::string error_msg = status.StatusMessage();
std::string msg = "Predict failed, " + error_msg;
status = msg;
return status;
}

MSI_TIME_STAMP_START(CreateReplyJson)
status = ParseReply(reply, out_json);
@@ -1037,11 +1031,6 @@ Status RestfulService::CheckReply(const ProtoTensor &pb_tensor) {
return status;
}

void RestfulService::ParseErrorMsg(const proto::ErrorMsg &error, json *const js) {
std::string str = error.error_msg();
*js = str;
}

// 5.Parse reply
Status RestfulService::ParseReply(const PredictReply &reply, json *const out_json) {
Status status(SUCCESS);
@@ -1059,56 +1048,42 @@ Status RestfulService::ParseReply(const PredictReply &reply, json *const out_jso
Status RestfulService::ParseInstancesReply(const PredictReply &reply, json *const out_json) {
Status status(SUCCESS);
auto error_size = reply.error_msg_size();
if (error_size != 0 && error_size != 1 && error_size != instances_nums_) {
auto reply_size = reply.instances().size();
if (error_size == 1 && reply_size == 0) {
(*out_json)[kErrorMsg] = reply.error_msg()[0].error_msg();
return SUCCESS;
}
if (error_size != 0 && error_size != instances_nums_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "reply error size:" << error_size << " is not 0,1 or instances size";
}
if (reply_size != instances_nums_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "reply size:" << reply_size << " is not matched request size:" << instances_nums_;
}

(*out_json)[kInstancesReply] = json();
json &instances_json = (*out_json)[kInstancesReply];

int32_t reply_num = reply.instances().size();
if (reply_num == 0) {
reply_num = error_size;
}
if (error_size == 0 && reply_num != instances_nums_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "reply size:" << reply_num << " is not matched request size:" << instances_nums_;
}

for (int32_t i = 0; i < reply_num; i++) {
bool success_flag = true;
if (i < error_size) {
auto &cur_error = reply.error_msg().at(i);
success_flag = (cur_error.error_code() == 0);
for (int32_t i = 0; i < instances_nums_; i++) {
instances_json.push_back(json());
auto &instance = instances_json.back();
if (error_size != 0 && reply.error_msg()[i].error_code() != 0) {
instance[kErrorMsg] = reply.error_msg(i).error_msg();
continue;
}
auto &cur_instance = reply.instances(i);
auto &items = cur_instance.items();
if (items.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty";
}

if (success_flag) {
if (i >= reply.instances_size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "index:" << i << " is more than reply instances size:" << reply.instances_size();
}
auto &cur_instance = reply.instances(i);
auto &items = cur_instance.items();
if (items.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty";
}
instances_json.push_back(json());
auto &instance = instances_json.back();

for (auto &item : items) {
instance[item.first] = json();
auto &value_json = instance[item.first];
status = ParseReplyDetail(item.second, &value_json);
if (status != SUCCESS) {
return status;
}
for (auto &item : items) {
instance[item.first] = json();
auto &value_json = instance[item.first];
status = ParseReplyDetail(item.second, &value_json);
if (status != SUCCESS) {
return status;
}
} else {
instances_json.push_back(json());
auto &obj = instances_json.back();
obj[kErrorMsg] = json();
auto &js = obj[kErrorMsg];
ParseErrorMsg(reply.error_msg(i), &js);
}
}
return status;


+ 0
- 1
mindspore_serving/ccsrc/master/restful/http_process.h View File

@@ -98,7 +98,6 @@ class RestfulService {
Status ParseScalarData(const ProtoTensor &pb_tensor, bool is_bytes, size_t index, json *const js);
template <typename T>
bool IsString();
void ParseErrorMsg(const proto::ErrorMsg &error_msg, json *const js);

RequestType request_type_{kInvalidType};
InstancesType instances_type_{kInvalidWay};


+ 4
- 4
mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h View File

@@ -41,13 +41,13 @@ class MSDistributedImpl final : public MSWorkerImpl {
: MSWorkerImpl(server_address), servable_(servable) {}
~MSDistributedImpl() = default;
grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
proto::AgentRegisterReply *reply) override;
proto::AgentRegisterReply *reply);
grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
proto::AgentExitReply *reply) override;
proto::AgentExitReply *reply);
grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
proto::AgentFailedReply *reply) override;
proto::AgentFailedReply *reply);
grpc::Status AgentConfigAcquire(grpc::ServerContext *context, const proto::AgentConfigAcquireRequest *request,
proto::AgentConfigAcquireReply *reply) override;
proto::AgentConfigAcquireReply *reply);
private:
std::shared_ptr<DistributedServable> servable_;


+ 1
- 1
mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc View File

@@ -76,7 +76,7 @@ Status DistributedServable::PredictInner(const std::vector<TensorBasePtr> &input
auto msg_list = std::make_shared<std::vector<DistributedPredictMsg>>(rank_size);

for (size_t i = 0; i < rank_size; ++i) {
DispatchCallback callback = [msg_list, i](const Status &status) {
AsyncPredictCallback callback = [msg_list, i](const Status &status) {
msg_list->at(i).status = status;
msg_list->at(i).promise.set_value();
};


+ 1
- 1
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h View File

@@ -33,7 +33,7 @@ class MS_API BaseNotifyAgent {
virtual ~BaseNotifyAgent() = default;
virtual Status Exit() = 0;
virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
DispatchCallback callback) = 0;
AsyncPredictCallback callback) = 0;
};

} // namespace serving


+ 1
- 1
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc View File

@@ -55,7 +55,7 @@ Status GrpcNotifyAgent::Exit() {
}

Status GrpcNotifyAgent::DispatchAsync(const proto::DistributedPredictRequest &request,
proto::DistributedPredictReply *reply, DispatchCallback callback) {
proto::DistributedPredictReply *reply, AsyncPredictCallback callback) {
if (!stub_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_;


+ 1
- 1
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h View File

@@ -35,7 +35,7 @@ class MS_API GrpcNotifyAgent : public BaseNotifyAgent {
Status Exit() override;

Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
DispatchCallback callback) override;
AsyncPredictCallback callback) override;

private:
std::string agent_address_;


+ 8
- 14
mindspore_serving/ccsrc/worker/grpc/worker_process.cc View File

@@ -16,6 +16,7 @@
#include "worker/grpc/worker_process.h"
#include "worker/worker.h"
#include "common/proto_tensor.h"
namespace mindspore {
namespace serving {
@@ -26,13 +27,13 @@ grpc::Status MSWorkerImpl::Exit(grpc::ServerContext *context, const proto::ExitR
return grpc::Status::OK;
}
grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply, DispatchCallback callback) {
void MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply, PredictOnFinish on_finish) {
Status status(FAILED);
MSI_LOG(INFO) << "Begin call service Eval";
try {
MSI_TIME_STAMP_START(Predict)
status = Worker::GetInstance().RunAsync(*request, reply, callback);
status = Worker::GetInstance().RunAsync(*request, reply, on_finish);
MSI_TIME_STAMP_END(Predict)
} catch (const std::bad_alloc &ex) {
MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
@@ -49,19 +50,12 @@ grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const prot
}
MSI_LOG(INFO) << "Finish call service Eval";
if (status == INVALID_INPUTS) {
auto proto_error_msg = reply->add_error_msg();
proto_error_msg->set_error_code(status.StatusCode());
proto_error_msg->set_error_msg(status.StatusMessage());
return grpc::Status::OK;
} else if (status != SUCCESS) {
auto proto_error_msg = reply->add_error_msg();
proto_error_msg->set_error_code(FAILED);
proto_error_msg->set_error_msg("Predict failed");
return grpc::Status::OK;
if (status != SUCCESS) {
GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
on_finish();
}
return grpc::Status::OK;
}
grpc::Status MSWorkerImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request,
proto::PingReply *reply) {
MSI_EXCEPTION_IF_NULL(request);


+ 6
- 6
mindspore_serving/ccsrc/worker/grpc/worker_process.h View File

@@ -35,7 +35,7 @@ namespace mindspore {
namespace serving {
// Service Implement
class MSWorkerImpl : public proto::MSWorker::Service {
class MSWorkerImpl {
public:
explicit MSWorkerImpl(const std::string server_address) {
if (!watcher_) {
@@ -43,11 +43,11 @@ class MSWorkerImpl : public proto::MSWorker::Service {
}
}
grpc::Status PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply, DispatchCallback callback);
grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override;
grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override;
grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override;
void PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply,
PredictOnFinish on_finish);
grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply);
grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply);
grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply);
std::shared_ptr<Watcher<proto::MSAgent, proto::MSMaster>> watcher_;
};


+ 2
- 5
mindspore_serving/ccsrc/worker/grpc/worker_server.h View File

@@ -100,11 +100,8 @@ class WorkerPredictContext : public WorkerServiceContext {
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); };
grpc::Status status = service_impl_->PredictAsync(&ctx_, &request_, &response_, callback);
if (!status.ok()) {
responder_.Finish(response_, status, this);
}
PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); };
service_impl_->PredictAsync(&ctx_, &request_, &response_, on_finish);
}
private:


+ 12
- 12
mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc View File

@@ -141,22 +141,22 @@ std::shared_ptr<Context> MindSporeModelWrap::TransformModelContext(const std::ma
MSI_LOG_ERROR << "Set model context output type failed, unknown data type " << val;
}
};
std::map<std::string, ContextStrFun> option_map = {
{"acl_option.insert_op_config_file_path", mindspore::ModelContext::SetInsertOpConfigPath},
{"acl_option.input_format", mindspore::ModelContext::SetInputFormat},
{"acl_option.input_shape", mindspore::ModelContext::SetInputShape},
{"acl_option.output_type", set_output_type},
{"acl_option.precision_mode", mindspore::ModelContext::SetPrecisionMode},
{"acl_option.op_select_impl_mode", mindspore::ModelContext::SetOpSelectImplMode},
};
auto context = std::make_shared<mindspore::ModelContext>();
for (auto &item : options) {
const auto &key = item.first;
const auto &value = item.second;
auto it = option_map.find(key);
if (it != option_map.end()) {
MSI_LOG_INFO << "Set context options, key: " << key << ", value: " << value;
it->second(context, value);
if (key == "acl_option.insert_op_config_file_path") {
mindspore::ModelContext::SetInsertOpConfigPath(context, value);
} else if (key == "acl_option.input_format") {
mindspore::ModelContext::SetInputFormat(context, value);
} else if (key == "acl_option.input_shape") {
mindspore::ModelContext::SetInputShape(context, value);
} else if (key == "acl_option.output_type") {
set_output_type(context, value);
} else if (key == "acl_option.precision_mode") {
mindspore::ModelContext::SetPrecisionMode(context, value);
} else if (key == "acl_option.op_select_impl_mode") {
mindspore::ModelContext::SetOpSelectImplMode(context, value);
}
}
return context;


+ 4
- 11
mindspore_serving/ccsrc/worker/worker.cc View File

@@ -60,7 +60,7 @@ Status Worker::RemoveWorker(const ServableWorkerContext &work) {
return notify_master_->RemoveWorker(work.worker_spec);
}

Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback) {
Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish) {
std::shared_lock<std::shared_mutex> lock(worker_shared_lock_);
if (!servable_started_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "RunAsync worker for inference failed, worker has not been started";
@@ -81,16 +81,9 @@ Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictRepl
if (worker.worker_service == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find servable match " << request_spec.Repr();
}
WorkCallBack on_process_done = [request, reply, callback](const std::vector<InstancePtr> &instances) {
auto status = GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply);
if (status != SUCCESS) {
MSI_LOG_ERROR << "transfer result to reply failed";
reply->clear_error_msg();
auto proto_error = reply->add_error_msg();
proto_error->set_error_code(status.StatusCode());
proto_error->set_error_msg(status.StatusMessage());
}
callback(SUCCESS);
WorkCallBack on_process_done = [request, reply, on_finish](const std::vector<InstancePtr> &instances) {
GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply);
on_finish();
};
return worker.worker_service->Work(request_spec, instances_data, on_process_done);
}


+ 1
- 1
mindspore_serving/ccsrc/worker/worker.h View File

@@ -53,7 +53,7 @@ class MS_API Worker {
static Worker &GetInstance();
void Clear();

Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback);
Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish);
Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master);

Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip,


+ 2
- 5
tests/ut/cpp/common/test_servable_common.h View File

@@ -376,11 +376,8 @@ class TestMasterWorkerClient : public TestMasterWorker {
grpc::ServerContext context;
auto promise = std::make_shared<std::promise<void>>();
auto future = promise->get_future();
DispatchCallback callback = [promise](Status status) { promise->set_value(); };
auto status = impl.PredictAsync(&request, reply, callback);
if (!status.IsSuccess()) {
return grpc::Status::OK;
}
PredictOnFinish callback = [promise]() { promise->set_value(); };
impl.PredictAsync(&request, reply, callback);
future.get();
return grpc::Status::OK;
}


+ 368
- 1
tests/ut/python/tests/test_mater_worker_client.py View File

@@ -173,7 +173,7 @@ def test_grpc_start_restful_server_twice_failed():


@serving_test
def test_grpc_alone_repeat_master_and_woker_port_failed():
def test_grpc_alone_repeat_master_and_worker_port_failed():
base = ServingTestBase()
base.init_servable(1, "add_servable_config.py")
master.start_master_server(master_port=7600)
@@ -731,3 +731,370 @@ def add_common(x1, x2):
assert result[0]["y"] == 0
assert "Preprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"])
assert result[0]["y"] == 0


@serving_test
def test_servable_worker_with_master_preprocess_runtime_error():
# fail returned from Preprocess
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
index = 0
def preprocess(instances):
count = len(instances)
global index
for i in range(count):
ret = index
index += 1
if ret == 0:
raise RuntimeError("runtime error")
yield ret
@register.register_method(output_names=["y"])
def add_common(x1, x2):
x3 = register.call_preprocess_pipeline(preprocess, x1)
y = register.call_servable(x1, x2)
return x3
"""
base.init_servable_with_servable_config(1, servable_content)
worker.start_servable_in_master(base.servable_dir, base.servable_name)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Preprocess Failed" in result[0]["error"]
assert result[1]["y"] == 1
assert result[2]["y"] == 2


@serving_test
def test_servable_worker_alone_preprocess_runtime_error():
# fail returned from Preprocess
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
index = 0
def preprocess(instances):
count = len(instances)
global index
for i in range(count):
ret = index
index += 1
if ret == 0:
raise RuntimeError("runtime error")
yield ret
@register.register_method(output_names=["y"])
def add_common(x1, x2):
x3 = register.call_preprocess_pipeline(preprocess, x1)
y = register.call_servable(x1, x2)
return x3
"""
base.init_servable_with_servable_config(1, servable_content)
master.start_master_server("0.0.0.0", 6100)
worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Preprocess Failed" in result[0]["error"]
assert result[1]["y"] == 1
assert result[2]["y"] == 2


@serving_test
def test_servable_worker_with_master_predict_check_failed():
# fail returned from Predict
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
@register.register_method(output_names=["y"])
def add_common(x1, x2):
y = register.call_servable(x1, x2)
return y
"""
base.init_servable_with_servable_config(1, servable_content)
worker.start_servable_in_master(base.servable_dir, base.servable_name)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
if i == 0:
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
else:
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"]
assert "y" in result[1]
assert "y" in result[2]


@serving_test
def test_servable_worker_alone_predict_check_failed():
# fail returned from Predict
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
@register.register_method(output_names=["y"])
def add_common(x1, x2):
y = register.call_servable(x1, x2)
return y
"""
base.init_servable_with_servable_config(1, servable_content)
master.start_master_server("0.0.0.0", 6100)
worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
if i == 0:
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
else:
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"]
assert "y" in result[1]
assert "y" in result[2]


@serving_test
def test_servable_worker_with_master_postprocess_runtime_error():
# fail returned from Preprocess
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
index = 0
def postprocess(y):
global index
ret = index
index += 1
if ret == 0:
raise RuntimeError("runtime error")
return ret
@register.register_method(output_names=["y"])
def add_common(x1, x2):
y = register.call_servable(x1, x2)
y = register.call_postprocess(postprocess, y)
return y
"""
base.init_servable_with_servable_config(1, servable_content)
worker.start_servable_in_master(base.servable_dir, base.servable_name)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Postprocess Failed" in result[0]["error"]
assert result[1]["y"] == 1
assert result[2]["y"] == 2


@serving_test
def test_servable_worker_alone_postprocess_runtime_error():
# fail returned from Preprocess
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += r"""
index = 0
def postprocess(y):
global index
ret = index
index += 1
if ret == 0:
raise RuntimeError("runtime error")
return ret
@register.register_method(output_names=["y"])
def add_common(x1, x2):
y = register.call_servable(x1, x2)
y = register.call_postprocess(postprocess, y)
return y
"""
base.init_servable_with_servable_config(1, servable_content)
master.start_master_server("0.0.0.0", 6100)
worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x1": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Postprocess Failed" in result[0]["error"]
assert result[1]["y"] == 1
assert result[2]["y"] == 2


@serving_test
def test_servable_worker_with_master_input_param_less():
# fail returned from Worker::RunAsync
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += servable_config_method_add_common
base.init_servable_with_servable_config(1, servable_content)
worker.start_servable_in_master(base.servable_dir, base.servable_name)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x3": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Cannot find input x1 in instance input" in result["error"]


@serving_test
def test_servable_worker_alone_input_param_less():
# fail returned from Worker::RunAsync
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += servable_config_method_add_common
base.init_servable_with_servable_config(1, servable_content)
master.start_master_server("0.0.0.0", 6100)
worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x3": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name, "add_common")
result = client.infer(instances)
print(result)
assert "Cannot find input x1 in instance input" in result["error"]


@serving_test
def test_servable_worker_with_master_servable_not_available():
# fail returned from Worker::RunAsync
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += servable_config_method_add_common
base.init_servable_with_servable_config(1, servable_content)
worker.start_servable_in_master(base.servable_dir, base.servable_name)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x3": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name + "error", "add_common")
result = client.infer(instances)
print(result)
assert "servable is not available" in result["error"]


@serving_test
def test_servable_worker_alone_servable_not_available():
# fail returned from Worker::RunAsync
base = ServingTestBase()
servable_content = servable_config_import
servable_content += servable_config_declare_servable
servable_content += servable_config_method_add_common
base.init_servable_with_servable_config(1, servable_content)
master.start_master_server("0.0.0.0", 6100)
worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100)
master.start_grpc_server("0.0.0.0", 5500)
# Client
instance_count = 3

instances = []
y_data_list = []
for i in range(instance_count):
x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1)
x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1)
y_data_list.append(x1 + x2)
instances.append({"x3": x1, "x2": x2})

client = create_client("localhost", 5500, base.servable_name + "error", "add_common")
result = client.infer(instances)
print(result)
assert "servable is not available" in result["error"]

+ 85
- 34
tests/ut/stub/cxx_api/context.cc View File

@@ -14,6 +14,9 @@
* limitations under the License.
*/
#include "include/api/context.h"
#include <any>
#include <map>
#include <type_traits>
#include "utils/log_adapter.h"

constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
@@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";

namespace mindspore {
template <class T>
static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
auto iter = context->params.find(key);
if (iter == context->params.end()) {
return T();
struct Context::Data {
std::map<std::string, std::any> params;
};

Context::Context() : data(std::make_shared<Data>()) {}

template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
static U empty_result;
if (context == nullptr || context->data == nullptr) {
return empty_result;
}
auto iter = context->data->params.find(key);
if (iter == context->data->params.end()) {
return empty_result;
}
const std::any &value = iter->second;
if (value.type() != typeid(T)) {
return T();
if (value.type() != typeid(U)) {
return empty_result;
}

return std::any_cast<T>(value);
return std::any_cast<const U &>(value);
}

std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
@@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
return g_context;
}

void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceTarget] = device_target;
if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
}

std::string GlobalContext::GetGlobalDeviceTarget() {
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
return StringToChar(ref);
}

void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
global_context->params[kGlobalContextDeviceID] = device_id;
if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDeviceID] = device_id;
}

uint32_t GlobalContext::GetGlobalDeviceID() {
@@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
}

void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInsertOpCfgPath] = cfg_path;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
}

std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
return StringToChar(ref);
}

void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputFormat] = format;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputFormat] = CharToString(format);
}

std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputFormat);
const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
return StringToChar(ref);
}

void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionInputShape] = shape;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionInputShape] = CharToString(shape);
}

std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionInputShape);
const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
return StringToChar(ref);
}

void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOutputType] = output_type;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOutputType] = output_type;
}

enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
@@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex
return GetValue<enum DataType>(context, kModelOptionOutputType);
}

void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionPrecisionMode] = precision_mode;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
}

std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionPrecisionMode);
const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
return StringToChar(ref);
}

void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(context);
context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
}

std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
return StringToChar(ref);
}
} // namespace mindspore

tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.cc → tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.cc View File

@@ -13,38 +13,38 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/graph/ms/ms_graph_impl.h"
#include "cxx_api/graph/ascend/ascend_graph_impl.h"
#include <algorithm>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "stub/graph_impl_stub.h"

namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);

std::shared_ptr<GraphCell::GraphImpl> MsGraphImpl::graph_imp_stub_ = std::make_shared<GraphImplStubAdd>();
std::shared_ptr<GraphCell::GraphImpl> AscendGraphImpl::graph_imp_stub_ = std::make_shared<GraphImplStubAdd>();

MsGraphImpl::MsGraphImpl() {}
AscendGraphImpl::AscendGraphImpl() {}

MsGraphImpl::~MsGraphImpl() {}
AscendGraphImpl::~AscendGraphImpl() {}

std::vector<MSTensor> MsGraphImpl::GetInputs() {
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
if (!graph_imp_stub_) {
return {};
}
return graph_imp_stub_->GetInputs();
}

std::vector<MSTensor> MsGraphImpl::GetOutputs() {
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
if (!graph_imp_stub_) {
return {};
}
return graph_imp_stub_->GetOutputs();
}

Status MsGraphImpl::Load() { return kSuccess; }
Status AscendGraphImpl::Load() { return kSuccess; }

Status MsGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
if (!graph_imp_stub_) {
return kMCFailed;
}

tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.h → tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h View File

@@ -13,25 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <mutex>
#include "include/api/status.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_impl.h"
#include "cxx_api/model/model_impl.h"

namespace mindspore {
class MsGraphImpl : public GraphCell::GraphImpl {

class AscendGraphImpl : public GraphCell::GraphImpl {
public:
MsGraphImpl();
~MsGraphImpl() override;
AscendGraphImpl();
~AscendGraphImpl() override;

Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
@@ -43,4 +43,4 @@ class MsGraphImpl : public GraphCell::GraphImpl {
};

} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H

+ 25
- 3
tests/ut/stub/cxx_api/model/model.cc View File

@@ -17,8 +17,16 @@
#include "include/api/context.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/factory.h"
#include "utils/utils.h"

namespace mindspore {
namespace {
const std::map<std::string, std::set<ModelType>> kSupportedModelMap = {
{kDeviceTypeAscend310, {kOM, kMindIR}},
{kDeviceTypeAscend910, {kMindIR}},
{kDeviceTypeGPU, {kMindIR}},
};
}
Status Model::Build() {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Build();
@@ -60,8 +68,22 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context>

Model::~Model() {}

bool Model::CheckModelSupport(const std::string &device_type, ModelType) {
return Factory<ModelImpl>::Instance().CheckModelSupport(device_type);
}
bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) {
std::string device_type_str = CharToString(device_type);
if (!Factory<ModelImpl>::Instance().CheckModelSupport(device_type_str)) {
return false;
}

auto first_iter = kSupportedModelMap.find(device_type_str);
if (first_iter == kSupportedModelMap.end()) {
return false;
}

auto secend_iter = first_iter->second.find(model_type);
if (secend_iter == first_iter->second.end()) {
return false;
}

return true;
}
} // namespace mindspore

+ 4
- 1
tests/ut/stub/cxx_api/model/model_impl.h View File

@@ -25,6 +25,7 @@
#include "include/api/model.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"

namespace mindspore {
class ModelImpl {
@@ -63,7 +64,9 @@ class ModelImpl {
friend class Model;
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &model_context) {
model_context_ = std::make_shared<Context>(*model_context);
if (model_context != nullptr) {
model_context_ = std::make_shared<Context>(*model_context);
}
}
};
} // namespace mindspore


+ 5
- 0
tests/ut/stub/cxx_api/model/ms/ms_model.cc View File

@@ -66,6 +66,11 @@ Status MsModel::Build() {
MS_LOG(INFO) << "Start build model.";
MS_EXCEPTION_IF_NULL(graph_);

if (graph_cell_ != nullptr) {
MS_LOG(INFO) << "This model has been built, skip.";
return kSuccess;
}

auto func_graph = ModelImpl::GetFuncGraph();
MS_EXCEPTION_IF_NULL(func_graph);



+ 8
- 4
tests/ut/stub/cxx_api/serialization.cc View File

@@ -16,7 +16,7 @@
#include "include/api/serialization.h"
#include <fstream>
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
static Buffer ReadFile(const std::string &file) {
@@ -77,13 +77,17 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
}

Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Buffer data = ReadFile(file);
Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) {
std::string file_path = CharToString(file);
Buffer data = ReadFile(file_path);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
MS_LOG(EXCEPTION) << "Read file " << file_path << " failed.";
}
if (model_type == kMindIR) {
auto anf_graph = std::make_shared<FuncGraph>();
if (anf_graph == nullptr) {
MS_LOG(EXCEPTION) << "Load model failed.";
}
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) {
return Graph(std::make_shared<Graph::GraphData>(data, kOM));


+ 207
- 0
tests/ut/stub/cxx_api/status.cc View File

@@ -0,0 +1,207 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "include/api/status.h"
#ifndef ENABLE_ANDROID
#include <thread>
#endif
#include <map>
#include <sstream>

namespace mindspore {
// Internal shared state of Status (pimpl). Held in a shared_ptr so Status
// copies are cheap and all copies share one payload.
struct Status::Data {
  enum StatusCode status_code = kSuccess;  // numeric result code, defaults to success
  std::string status_msg;                  // fully formatted message (thread id, code, file, line)
  int line_of_code = -1;                   // -1 means the source line is unknown
  std::string file_name;                   // source file that created the status, may be empty
  std::string err_description;             // caller-supplied extra description, may be empty
};

// Default status: success with an empty message.
Status::Status() : data_(std::make_shared<Data>()) {}

// Construct a status from a code and a message (char-vector ABI variant).
Status::Status(enum StatusCode status_code, const std::vector<char> &status_msg) : data_(std::make_shared<Data>()) {
  if (data_ == nullptr) {
    // NOTE(review): std::make_shared reports failure by throwing, so this
    // guard is effectively dead; kept for consistency with the other ctor.
    return;
  }

  data_->status_msg = CharToString(status_msg);
  data_->status_code = status_code;
}

// Construct a status carrying code, source location and an extra description
// (char-vector ABI variant). Pre-renders the full status_msg string here so
// later ToString() calls are just a copy.
Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra)
    : data_(std::make_shared<Data>()) {
  if (data_ == nullptr) {
    // NOTE(review): make_shared signals failure by throwing; guard kept defensively.
    return;
  }
  data_->status_code = code;
  data_->line_of_code = line_of_code;
  if (file_name != nullptr) {
    data_->file_name = file_name;
  }
  data_->err_description = CharToString(extra);

  std::ostringstream ss;
#ifndef ENABLE_ANDROID
  // Thread id + code name + description are omitted on Android builds,
  // where <thread> is not pulled in.
  ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". ";
  if (!data_->err_description.empty()) {
    ss << data_->err_description;
  }
  ss << "\n";
#endif

  // Location lines are always appended here, even when line_of_code is -1
  // or file_name is null (the "File :" line is then skipped).
  ss << "Line of code : " << line_of_code << "\n";
  if (file_name != nullptr) {
    ss << "File : " << file_name << "\n";
  }
  data_->status_msg = ss.str();
}

// Returns the stored code; an empty (moved-from) Status reads as success.
enum StatusCode Status::StatusCode() const { return data_ == nullptr ? kSuccess : data_->status_code; }

// Returns the formatted status message in char-vector (dual ABI) form;
// empty when there is no shared state.
std::vector<char> Status::ToCString() const {
  if (data_ == nullptr) {
    return {};
  }
  return StringToChar(data_->status_msg);
}

// Source line recorded at construction; -1 when unknown or state is empty.
int Status::GetLineOfCode() const { return data_ == nullptr ? -1 : data_->line_of_code; }

// Returns the caller-supplied error description (set via the detailed
// constructor or SetErrDescription). Falls back to the full formatted
// status message when no description was recorded.
std::vector<char> Status::GetErrDescriptionChar() const {
  if (data_ == nullptr) {
    return std::vector<char>();
  }
  // Bug fix: this accessor previously returned status_msg (the fully
  // formatted message with thread id / code / file / line) even though
  // err_description is tracked separately and is what the name promises.
  if (data_->err_description.empty()) {
    return StringToChar(data_->status_msg);
  }
  return StringToChar(data_->err_description);
}

// Maps a status code to a short human-readable description (char-vector ABI
// variant). Unknown codes map to "Unknown error". The table is a
// function-local static, built once on first use.
std::vector<char> Status::CodeAsCString(enum StatusCode c) {
  static std::map<enum StatusCode, std::string> info_map = {{kSuccess, "No error occurs."},
                                                            // Core
                                                            {kCoreFailed, "Common error code."},
                                                            // MD
                                                            {kMDOutOfMemory, "Out of memory"},
                                                            {kMDShapeMisMatch, "Shape is incorrect"},
                                                            {kMDInterrupted, "Interrupted system call"},
                                                            {kMDNoSpace, "No space left on device"},
                                                            {kMDPyFuncException, "Exception thrown from PyFunc"},
                                                            {kMDDuplicateKey, "Duplicate key"},
                                                            {kMDPythonInterpreterFailure, ""},
                                                            {kMDTDTPushFailure, "Unexpected error"},
                                                            {kMDFileNotExist, "Unexpected error"},
                                                            {kMDProfilingError, "Error encountered while profiling"},
                                                            {kMDBoundingBoxOutOfBounds, "Unexpected error"},
                                                            {kMDBoundingBoxInvalidShape, "Unexpected error"},
                                                            {kMDSyntaxError, "Syntax error"},
                                                            {kMDTimeOut, "Unexpected error"},
                                                            {kMDBuddySpaceFull, "BuddySpace full"},
                                                            {kMDNetWorkError, "Network error"},
                                                            {kMDNotImplementedYet, "Unexpected error"},
                                                            {kMDUnexpectedError, "Unexpected error"},
                                                            // ME
                                                            {kMEFailed, "Common error code."},
                                                            {kMEInvalidInput, "Invalid input."},
                                                            // MC
                                                            {kMCFailed, "Common error code."},
                                                            {kMCDeviceError, "Device error."},
                                                            {kMCInvalidInput, "Invalid input."},
                                                            {kMCInvalidArgs, "Invalid arguments."},
                                                            // Lite
                                                            {kLiteError, "Common error code."},
                                                            {kLiteNullptr, "NULL pointer returned."},
                                                            {kLiteParamInvalid, "Invalid parameter."},
                                                            {kLiteNoChange, "No change."},
                                                            {kLiteSuccessExit, "No error but exit."},
                                                            {kLiteMemoryFailed, "Fail to create memory."},
                                                            {kLiteNotSupport, "Fail to support."},
                                                            {kLiteThreadPoolError, "Thread pool error."},
                                                            {kLiteOutOfTensorRange, "Failed to check range."},
                                                            {kLiteInputTensorError, "Failed to check input tensor."},
                                                            {kLiteReentrantError, "Exist executor running."},
                                                            {kLiteGraphFileError, "Failed to verify graph file."},
                                                            {kLiteNotFindOp, "Failed to find operator."},
                                                            {kLiteInvalidOpName, "Invalid operator name."},
                                                            {kLiteInvalidOpAttr, "Invalid operator attr."},
                                                            {kLiteOpExecuteFailure, "Failed to execution operator."},
                                                            {kLiteFormatError, "Failed to checking tensor format."},
                                                            {kLiteInferError, "Failed to infer shape."},
                                                            {kLiteInferInvalid, "Invalid infer shape before runtime."},
                                                            {kLiteInputParamInvalid, "Invalid input param by user."}};
  auto iter = info_map.find(c);
  return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second);
}

// Stream the formatted status message.
std::ostream &operator<<(std::ostream &os, const Status &s) { return os << s.ToString(); }

// Replaces the stored error description and re-renders status_msg from it.
// Returns the new formatted message in char-vector form (empty when there
// is no shared state).
std::vector<char> Status::SetErrDescription(const std::vector<char> &err_description) {
  if (data_ == nullptr) {
    return std::vector<char>();
  }
  data_->err_description = CharToString(err_description);
  std::ostringstream ss;
#ifndef ENABLE_ANDROID
  ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". ";
  if (!data_->err_description.empty()) {
    ss << data_->err_description;
  }
  ss << "\n";
#endif

  // NOTE(review): unlike the detailed constructor, the location lines are
  // appended only when both a positive line and a file name were recorded —
  // confirm this asymmetry is intended.
  if (data_->line_of_code > 0 && !data_->file_name.empty()) {
    ss << "Line of code : " << data_->line_of_code << "\n";
    ss << "File : " << data_->file_name << "\n";
  }
  data_->status_msg = ss.str();
  return StringToChar(data_->status_msg);
}

// Two statuses compare equal when both are empty, or when both hold the
// same status code (messages are ignored).
bool Status::operator==(const Status &other) const {
  if (data_ == nullptr || other.data_ == nullptr) {
    // Equal only if both sides are empty.
    return data_ == other.data_;
  }
  return data_->status_code == other.data_->status_code;
}

bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; }
bool Status::operator!=(const Status &other) const { return !operator==(other); }
bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); }

Status::operator bool() const { return (StatusCode() == kSuccess); }
Status::operator int() const { return static_cast<int>(StatusCode()); }

Status Status::OK() { return StatusCode::kSuccess; }
bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool Status::IsError() const { return !IsOk(); }
} // namespace mindspore

+ 11
- 9
tests/ut/stub/cxx_api/types.cc View File

@@ -133,10 +133,11 @@ class TensorReferenceImpl : public MSTensor::Impl {
std::vector<int64_t> shape_;
};

MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len);
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
@@ -147,10 +148,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con
}
}

MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
std::string name_str = CharToString(name);
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len);
std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len);
return MSTensor(impl);
} catch (const std::bad_alloc &) {
MS_LOG(ERROR) << "Malloc memory failed.";
@@ -164,9 +166,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type,
MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {}
MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {}
MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); }
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {}
MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len)
: impl_(std::make_shared<TensorDefaultImpl>(CharToString(name), type, shape, data, data_len)) {}
MSTensor::~MSTensor() = default;

bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; }
@@ -178,9 +180,9 @@ MSTensor MSTensor::Clone() const {
return ret;
}

const std::string &MSTensor::Name() const {
std::vector<char> MSTensor::CharName() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->Name();
return StringToChar(impl_->Name());
}

enum DataType MSTensor::DataType() const {


+ 86
- 15
tests/ut/stub/include/api/context.h View File

@@ -16,49 +16,120 @@
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <map>
#include <any>
#include <string>
#include <memory>
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
constexpr auto kDeviceTypeAscend310 = "Ascend310";
constexpr auto kDeviceTypeAscend910 = "Ascend910";
constexpr auto kDeviceTypeGPU = "GPU";

struct MS_API Context {
public:
Context();
virtual ~Context() = default;
std::map<std::string, std::any> params;
struct Data;
std::shared_ptr<Data> data;
};

struct MS_API GlobalContext : public Context {
public:
static std::shared_ptr<Context> GetGlobalContext();

static void SetGlobalDeviceTarget(const std::string &device_target);
static std::string GetGlobalDeviceTarget();
static inline void SetGlobalDeviceTarget(const std::string &device_target);
static inline std::string GetGlobalDeviceTarget();

static void SetGlobalDeviceID(const uint32_t &device_id);
static uint32_t GetGlobalDeviceID();

private:
// api without std::string
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
static std::vector<char> GetGlobalDeviceTargetChar();
};

struct MS_API ModelContext : public Context {
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
public:
static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);

static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static std::string GetInputFormat(const std::shared_ptr<Context> &context);
static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);

static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static std::string GetInputShape(const std::shared_ptr<Context> &context);
static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
static inline std::string GetInputShape(const std::shared_ptr<Context> &context);

static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);

static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);

static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode);
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);

private:
// api without std::string
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);

static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);

static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);

static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);

static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode);
static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::vector<char> &op_select_impl_mode);
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
};

void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
SetGlobalDeviceTarget(StringToChar(device_target));
}
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }

void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
SetInsertOpConfigPath(context, StringToChar(cfg_path));
}
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
return CharToString(GetInsertOpConfigPathChar(context));
}

void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
SetInputFormat(context, StringToChar(format));
}
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
return CharToString(GetInputFormatChar(context));
}

void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
SetInputShape(context, StringToChar(shape));
}
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
return CharToString(GetInputShapeChar(context));
}

void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
SetPrecisionMode(context, StringToChar(precision_mode));
}
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
return CharToString(GetPrecisionModeChar(context));
}

void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::string &op_select_impl_mode) {
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
}
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
return CharToString(GetOpSelectImplModeChar(context));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

+ 1
- 4
tests/ut/stub/include/api/data_type.h View File

@@ -1,7 +1,5 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_



+ 26
- 0
tests/ut/stub/include/api/dual_abi_helper.h View File

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_

#include <string>
#include <vector>

namespace mindspore {
// std::string <-> std::vector<char> adapters for the dual-ABI boundary:
// std::string never crosses the library interface, only char vectors do.
inline std::vector<char> StringToChar(const std::string &s) { return {s.begin(), s.end()}; }
inline std::string CharToString(const std::vector<char> &c) { return {c.begin(), c.end()}; }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_

+ 0
- 1
tests/ut/stub/include/api/graph.h View File

@@ -17,7 +17,6 @@
#define MINDSPORE_INCLUDE_API_GRAPH_H

#include <cstddef>
#include <string>
#include <vector>
#include <map>
#include <memory>


+ 8
- 1
tests/ut/stub/include/api/model.h View File

@@ -25,6 +25,7 @@
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/cell.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
class ModelImpl;
@@ -46,10 +47,16 @@ class MS_API Model {
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();

static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);

private:
// api without std::string
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
std::shared_ptr<ModelImpl> impl_;
};

bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
return CheckModelSupport(StringToChar(device_type), model_type);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H

+ 9
- 1
tests/ut/stub/include/api/serialization.h View File

@@ -24,16 +24,24 @@
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
class MS_API Serialization {
public:
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
static Graph LoadModel(const std::string &file, ModelType model_type);
inline static Graph LoadModel(const std::string &file, ModelType model_type);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);

private:
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
};

Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
return LoadModel(StringToChar(file), model_type);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

+ 45
- 19
tests/ut/stub/include/api/status.h View File

@@ -16,9 +16,13 @@
#ifndef MINDSPORE_INCLUDE_API_STATUS_H
#define MINDSPORE_INCLUDE_API_STATUS_H

#include <memory>
#include <string>
#include <vector>
#include <ostream>
#include <climits>
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"

namespace mindspore {
enum CompCode : uint32_t {
@@ -100,39 +104,61 @@ enum StatusCode : uint32_t {
kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */
};

class Status {
class MS_API Status {
public:
Status() : status_code_(kSuccess) {}
Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit)
: status_code_(status_code), status_msg_(status_msg) {}
Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");
Status();
inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit)
inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = "");

~Status() = default;

enum StatusCode StatusCode() const { return status_code_; }
const std::string &ToString() const { return status_msg_; }
enum StatusCode StatusCode() const;
inline std::string ToString() const;

int GetLineOfCode() const;
inline std::string GetErrDescription() const;
inline std::string SetErrDescription(const std::string &err_description);

friend std::ostream &operator<<(std::ostream &os, const Status &s);

bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
bool operator==(const Status &other) const;
bool operator==(enum StatusCode other_code) const;
bool operator!=(const Status &other) const;
bool operator!=(enum StatusCode other_code) const;

explicit operator bool() const { return (status_code_ == kSuccess); }
explicit operator int() const { return static_cast<int>(status_code_); }
explicit operator bool() const;
explicit operator int() const;

static Status OK() { return Status(StatusCode::kSuccess); }
static Status OK();

bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); }
bool IsOk() const;

bool IsError() const { return !IsOk(); }
bool IsError() const;

static std::string CodeAsString(enum StatusCode c);
static inline std::string CodeAsString(enum StatusCode c);

private:
enum StatusCode status_code_;
std::string status_msg_;
// api without std::string
explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg);
Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
std::vector<char> ToCString() const;
std::vector<char> GetErrDescriptionChar() const;
std::vector<char> SetErrDescription(const std::vector<char> &err_description);
static std::vector<char> CodeAsCString(enum StatusCode c);

struct Data;
std::shared_ptr<Data> data_;
};

// Inline std::string adapters: the public string-based API forwards to the
// private char-vector ("dual ABI") entry points so std::string never crosses
// the library boundary.
Status::Status(enum StatusCode status_code, const std::string &status_msg)
    : Status(status_code, StringToChar(status_msg)) {}
Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra)
    : Status(code, line_of_code, file_name, StringToChar(extra)) {}
std::string Status::ToString() const { return CharToString(ToCString()); }
std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); }
std::string Status::SetErrDescription(const std::string &err_description) {
  return CharToString(SetErrDescription(StringToChar(err_description)));
}
std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_STATUS_H

+ 35
- 22
tests/ut/stub/include/api/types.h View File

@@ -21,22 +21,10 @@
#include <vector>
#include <memory>
#include "include/api/data_type.h"
#include "include/api/dual_abi_helper.h"

// refer to https://gcc.gnu.org/wiki/Visibility
#if defined _WIN32 || defined __CYGWIN__
#ifdef BUILDING_DLL
#ifdef __GNUC__
#define MS_API __attribute__((dllexport))
#else
#define MS_API __declspec(dllexport) // Note: actually gcc seems to also supports this syntax.
#endif
#else
#ifdef __GNUC__
#define MS_API __attribute__((dllimport))
#else
#define MS_API __declspec(dllimport) // Note: actually gcc seems to also supports this syntax.
#endif
#endif
#ifdef _WIN32
#define MS_API __declspec(dllexport)
#else
#define MS_API __attribute__((visibility("default")))
#endif
@@ -55,18 +43,18 @@ class MS_API MSTensor {
public:
class Impl;

static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;

MSTensor();
explicit MSTensor(const std::shared_ptr<Impl> &impl);
MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
~MSTensor();

const std::string &Name() const;
inline std::string Name() const;
enum DataType DataType() const;
const std::vector<int64_t> &Shape() const;
int64_t ElementNum() const;
@@ -81,6 +69,15 @@ class MS_API MSTensor {
bool operator==(std::nullptr_t) const;

private:
// api without std::string
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
std::vector<char> CharName() const;

friend class ModelImpl;
explicit MSTensor(std::nullptr_t);
std::shared_ptr<Impl> impl_;
@@ -123,5 +120,21 @@ class MS_API Buffer {
class Impl;
std::shared_ptr<Impl> impl_;
};

// Inline std::string adapters: the public string-based MSTensor API forwards
// to the private char-vector ("dual ABI") entry points.
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
                                const void *data, size_t data_len) noexcept {
  return CreateTensor(StringToChar(name), type, shape, data, data_len);
}

MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
                                   const void *data, size_t data_len) noexcept {
  return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
}

MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
                   size_t data_len)
    : MSTensor(StringToChar(name), type, shape, data, data_len) {}

// Name() round-trips through the char-vector accessor CharName().
std::string MSTensor::Name() const { return CharToString(CharName()); }
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_TYPES_H

+ 1
- 0
tests/ut/stub/include/utils/utils.h View File

@@ -21,6 +21,7 @@
#include <atomic>
#include <string>
#include <vector>
#include <set>
#include "utils/log_adapter.h"

namespace mindspore {


+ 1
- 1
third_party/mindspore

@@ -1 +1 @@
Subproject commit 52fac12367131ec57e87ba757e42fc25479f433a
Subproject commit dd22b5ea7106baf494704be04e2dbaad6887f0ab

Loading…
Cancel
Save