TensorFlow Serving C++ API Documentation
simple_servers_test.cc
1 /* Copyright 2016 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/servables/tensorflow/simple_servers.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include <gtest/gtest.h>
24 #include "tensorflow/cc/saved_model/loader.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/io/path.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/public/session.h"
33 #include "tensorflow_serving/core/servable_handle.h"
34 #include "tensorflow_serving/test_util/test_util.h"
35 
36 namespace tensorflow {
37 namespace serving {
38 namespace {
39 
40 class SimpleServersTest : public ::testing::Test {
41  protected:
42  SimpleServersTest()
43  : test_data_path_(test_util::TensorflowTestSrcDirPath(
44  "cc/saved_model/testdata/half_plus_two")) {}
45 
46  // Test that a SavedModelBundle handles a single request for the half plus two
47  // model properly. The request has size=2, for batching purposes.
48  void TestSingleRequest(const SavedModelBundle& bundle) {
49  const Tensor input = test::AsTensor<float>({100.0f, 42.0f}, {2});
50  // half plus two: output should be input / 2 + 2.
51  const Tensor expected_output =
52  test::AsTensor<float>({100.0f / 2 + 2, 42.0f / 2 + 2}, {2});
53 
54  // Note that "x" and "y" are the actual names of the nodes in the graph.
55  // The saved manifest binds these to "input" and "output" respectively, but
56  // these tests are focused on the raw underlying session without bindings.
57  const std::vector<std::pair<string, Tensor>> inputs = {{"x", input}};
58  const std::vector<string> output_names = {"y"};
59  const std::vector<string> empty_targets;
60  std::vector<Tensor> outputs;
61 
62  TF_ASSERT_OK(
63  bundle.session->Run(inputs, output_names, empty_targets, &outputs));
64 
65  ASSERT_EQ(1, outputs.size());
66  const auto& single_output = outputs.at(0);
67  test::ExpectTensorEqual<float>(expected_output, single_output);
68  }
69 
70  // Test data path, to be initialized to point at an export of half-plus-two.
71  const string test_data_path_;
72 };
73 
74 TEST_F(SimpleServersTest, Basic) {
75  std::unique_ptr<Manager> manager;
76  const Status status = simple_servers::CreateSingleTFModelManagerFromBasePath(
77  test_data_path_, &manager);
78  TF_CHECK_OK(status);
79  // We wait until the manager starts serving the servable.
80  // TODO(b/25545570): Use the waiter api when it's ready.
81  while (manager->ListAvailableServableIds().empty()) {
82  Env::Default()->SleepForMicroseconds(1000);
83  }
84  ServableHandle<SavedModelBundle> bundle;
85  const Status handle_status =
86  manager->GetServableHandle(ServableRequest::Latest("default"), &bundle);
87  TF_CHECK_OK(handle_status);
88  TestSingleRequest(*bundle);
89 }
90 
91 } // namespace
92 } // namespace serving
93 } // namespace tensorflow