|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <grpcpp/grpcpp.h>

#include "./ms_service.grpc.pb.h"
-
- using grpc::Channel;
- using grpc::ClientContext;
- using grpc::Status;
- using ms_serving::MSService;
- using ms_serving::PredictReply;
- using ms_serving::PredictRequest;
- using ms_serving::Tensor;
- using ms_serving::TensorShape;
-
- class MSClient {
- public:
- explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
-
- ~MSClient() = default;
-
- std::string Predict() {
- // Data we are sending to the server.
- PredictRequest request;
-
- Tensor data;
- TensorShape shape;
- shape.add_dims(4);
- *data.mutable_tensor_shape() = shape;
- data.set_tensor_type(ms_serving::MS_FLOAT32);
- std::vector<float> input_data{1, 2, 3, 4};
- data.set_data(input_data.data(), input_data.size() * sizeof(float));
- *request.add_data() = data;
- *request.add_data() = data;
- std::cout << "intput tensor size is " << request.data_size() << std::endl;
- // Container for the data we expect from the server.
- PredictReply reply;
-
- // Context for the client. It could be used to convey extra information to
- // the server and/or tweak certain RPC behaviors.
- ClientContext context;
-
- // The actual RPC.
- Status status = stub_->Predict(&context, request, &reply);
- std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
-
- // Act upon its status.
- if (status.ok()) {
- std::cout << "Add result is";
- for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
- std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
- }
- std::cout << std::endl;
- return "RPC OK";
- } else {
- std::cout << status.error_code() << ": " << status.error_message() << std::endl;
- return "RPC failed";
- }
- }
-
- private:
- std::unique_ptr<MSService::Stub> stub_;
- };
-
- int main(int argc, char **argv) {
- // Instantiate the client. It requires a channel, out of which the actual RPCs
- // are created. This channel models a connection to an endpoint specified by
- // the argument "--target=" which is the only expected argument.
- // We indicate that the channel isn't authenticated (use of
- // InsecureChannelCredentials()).
- std::string target_str;
- std::string arg_target_str("--target");
- if (argc > 1) {
- // parse target
- std::string arg_val = argv[1];
- size_t start_pos = arg_val.find(arg_target_str);
- if (start_pos != std::string::npos) {
- start_pos += arg_target_str.size();
- if (arg_val[start_pos] == '=') {
- target_str = arg_val.substr(start_pos + 1);
- } else {
- std::cout << "The only correct argument syntax is --target=" << std::endl;
- return 0;
- }
- } else {
- target_str = "localhost:5500";
- }
- } else {
- target_str = "localhost:5500";
- }
- MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
- std::string reply = client.Predict();
- std::cout << "client received: " << reply << std::endl;
-
- return 0;
- }
|