TensorFlow Serving C++ API Documentation
server.cc
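
Below is the source of tensorflow_serving/model_servers/server.cc: the Server class builds a ServerCore from command-line options and exposes it over gRPC (ModelService and PredictionService, plus optional profiler and health-check services) and, optionally, an HTTP/REST endpoint.
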
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_serving/model_servers/server.h"

#include <unistd.h>

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/wrappers.pb.h"
#include "grpc/grpc.h"
#include "grpcpp/health_check_service_interface.h"
#include "grpcpp/resource_quota.h"
#include "grpcpp/security/server_credentials.h"
#include "grpcpp/server_builder.h"
#include "grpcpp/server_context.h"
#include "grpcpp/support/status.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/config/model_server_config.pb.h"
#include "tensorflow_serving/config/monitoring_config.pb.h"
#include "tensorflow_serving/config/platform_config.pb.h"
#include "tensorflow_serving/config/ssl_config.pb.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/model_servers/server_init.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
#include "tensorflow_serving/util/proto_util.h"

namespace tensorflow {
namespace serving {
namespace main {

namespace {

tensorflow::Status LoadCustomModelConfig(
    const ::google::protobuf::Any& any,
    EventBus<ServableState>* servable_event_bus,
    UniquePtrWithDeps<AspiredVersionsManager>* manager) {
  LOG(FATAL)  // Crash ok
      << "ModelServer does not yet support custom model config.";
}

ModelServerConfig BuildSingleModelConfig(const string& model_name,
                                         const string& model_base_path) {
  ModelServerConfig config;
  LOG(INFO) << "Building single TensorFlow model file config: "
            << " model_name: " << model_name
            << " model_base_path: " << model_base_path;
  tensorflow::serving::ModelConfig* single_model =
      config.mutable_model_config_list()->add_config();
  single_model->set_name(model_name);
  single_model->set_base_path(model_base_path);
  single_model->set_model_platform(
      tensorflow::serving::kTensorFlowModelPlatform);
  return config;
}
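
// For illustration: with model_name "half_plus_two" and model_base_path
// "/models/half_plus_two" (both hypothetical), the function above produces a
// config equivalent to this text proto:
//
//   model_config_list {
//     config {
//       name: "half_plus_two"
//       base_path: "/models/half_plus_two"
//       model_platform: "tensorflow"
//     }
//   }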

// gRPC Channel Arguments to be passed from command line to gRPC ServerBuilder.
struct GrpcChannelArgument {
  string key;
  string value;
};

// Parses a comma-separated list of gRPC channel arguments into a list of
// GrpcChannelArgument. Each element must be of the form "key=value".
std::vector<GrpcChannelArgument> parseGrpcChannelArgs(
    const string& channel_arguments_str) {
  const std::vector<string> channel_arguments =
      tensorflow::str_util::Split(channel_arguments_str, ",");
  std::vector<GrpcChannelArgument> result;
  for (const string& channel_argument : channel_arguments) {
    const std::vector<string> key_val =
        tensorflow::str_util::Split(channel_argument, "=");
    result.push_back({key_val[0], key_val[1]});
  }
  return result;
}
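
// For illustration: a flag value such as (values hypothetical)
//
//   --grpc_channel_arguments="grpc.max_connection_age_ms=5000,grpc.so_reuseport=1"
//
// parses to {{"grpc.max_connection_age_ms", "5000"}, {"grpc.so_reuseport", "1"}};
// the entries are forwarded to ServerBuilder::AddChannelArgument in
// BuildAndStart below.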

// If 'use_alts_credentials', build secure server credentials using ALTS.
// Else if 'ssl_config_file' is non-empty, build using ssl.
// Otherwise use insecure channel.
std::shared_ptr<::grpc::ServerCredentials> BuildServerCredentials(
    bool use_alts_credentials, const string& ssl_config_file) {
  if (use_alts_credentials) {
    LOG(INFO) << "Using ALTS credentials";
    ::grpc::experimental::AltsServerCredentialsOptions alts_opts;
    return ::grpc::experimental::AltsServerCredentials(alts_opts);
  } else if (ssl_config_file.empty()) {
    LOG(INFO) << "Using InsecureServerCredentials";
    return ::grpc::InsecureServerCredentials();
  }

  SSLConfig ssl_config;
  TF_CHECK_OK(ParseProtoTextFile<SSLConfig>(ssl_config_file, &ssl_config));
  LOG(INFO) << "Using SSL credentials";

  ::grpc::SslServerCredentialsOptions ssl_ops(
      ssl_config.client_verify()
          ? GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
          : GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);

  ssl_ops.force_client_auth = ssl_config.client_verify();

  if (!ssl_config.custom_ca().empty()) {
    ssl_ops.pem_root_certs = ssl_config.custom_ca();
  }

  ::grpc::SslServerCredentialsOptions::PemKeyCertPair keycert = {
      ssl_config.server_key(), ssl_config.server_cert()};

  ssl_ops.pem_key_cert_pairs.push_back(keycert);

  return ::grpc::SslServerCredentials(ssl_ops);
}
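
// For illustration: an ssl_config_file is a text-format SSLConfig proto. Note
// that server_key and server_cert hold PEM contents, not file paths, since
// they are passed directly into PemKeyCertPair above. A sketch:
//
//   server_key: "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"
//   server_cert: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"
//   custom_ca: ""  # optional PEM CA bundle used to verify clients
//   client_verify: true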

}  // namespace

Server::Options::Options()
    : model_name("default"),
      saved_model_tags(tensorflow::kSavedModelTagServe) {}

Server::~Server() {
  // Note: Deletion of 'fs_config_polling_thread_' will block until our
  // underlying thread closure stops. Hence, destruction of this object will
  // not proceed until the thread has terminated.
  fs_config_polling_thread_.reset();
  WaitForTermination();
}

void Server::PollFilesystemAndReloadConfig(const string& config_file_path) {
  ModelServerConfig config;
  const Status read_status =
      ParseProtoTextFile<ModelServerConfig>(config_file_path, &config);
  if (!read_status.ok()) {
    LOG(ERROR) << "Failed to read ModelServerConfig file: "
               << read_status.message();
    return;
  }

  const Status reload_status = server_core_->ReloadConfig(config);
  if (!reload_status.ok()) {
    LOG(ERROR) << "PollFilesystemAndReloadConfig failed to ReloadConfig: "
               << reload_status.message();
  }
}
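
// For illustration: the polled config_file_path is a text-format
// ModelServerConfig, typically a model_config_list (model names and paths
// hypothetical):
//
//   model_config_list {
//     config {
//       name: "model_a"
//       base_path: "/models/model_a"
//       model_platform: "tensorflow"
//     }
//     config {
//       name: "model_b"
//       base_path: "/models/model_b"
//       model_platform: "tensorflow"
//     }
//   }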

Status Server::BuildAndStart(const Options& server_options) {
  if (server_options.grpc_port == 0 &&
      server_options.grpc_socket_path.empty()) {
    return errors::InvalidArgument(
        "At least one of server_options.grpc_port or "
        "server_options.grpc_socket_path must be set.");
  }

  if (server_options.use_alts_credentials &&
      !server_options.ssl_config_file.empty()) {
    return errors::InvalidArgument(
        "Either use_alts_credentials must be false or "
        "ssl_config_file must be empty.");
  }

  if (server_options.model_base_path.empty() &&
      server_options.model_config_file.empty()) {
    return errors::InvalidArgument(
        "Both server_options.model_base_path and "
        "server_options.model_config_file are empty!");
  }

  SetSignatureMethodNameCheckFeature(
      server_options.enable_signature_method_name_check);

  // For ServerCore Options, we leave servable_state_monitor_creator
  // unspecified so the default servable_state_monitor_creator will be used.
  ServerCore::Options options;

  // model server config
  if (server_options.model_config_file.empty()) {
    options.model_server_config = BuildSingleModelConfig(
        server_options.model_name, server_options.model_base_path);
  } else {
    TF_RETURN_IF_ERROR(ParseProtoTextFile<ModelServerConfig>(
        server_options.model_config_file, &options.model_server_config));
  }

  auto* tf_serving_registry =
      init::TensorflowServingFunctionRegistration::GetRegistry();

  if (server_options.platform_config_file.empty()) {
    SessionBundleConfig session_bundle_config;
    // Batching config
    if (server_options.enable_batching) {
      BatchingParameters* batching_parameters =
          session_bundle_config.mutable_batching_parameters();
      if (server_options.batching_parameters_file.empty()) {
        batching_parameters->mutable_thread_pool_name()->set_value(
            "model_server_batch_threads");
      } else {
        TF_RETURN_IF_ERROR(ParseProtoTextFile<BatchingParameters>(
            server_options.batching_parameters_file, batching_parameters));
      }
      if (server_options.enable_per_model_batching_params) {
        session_bundle_config.set_enable_per_model_batching_params(true);
      }
    } else if (!server_options.batching_parameters_file.empty()) {
      return errors::InvalidArgument(
          "server_options.batching_parameters_file is set without setting "
          "server_options.enable_batching to true.");
    }
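
    // For illustration: a batching_parameters_file is a text-format
    // BatchingParameters proto; its numeric fields are wrapped values, e.g.
    // (numbers hypothetical):
    //
    //   max_batch_size { value: 128 }
    //   batch_timeout_micros { value: 1000 }
    //   num_batch_threads { value: 8 }
    //   max_enqueued_batches { value: 100 }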

    if (!server_options.tensorflow_session_config_file.empty()) {
      TF_RETURN_IF_ERROR(
          ParseProtoTextFile(server_options.tensorflow_session_config_file,
                             session_bundle_config.mutable_session_config()));
    }

    session_bundle_config.mutable_session_config()
        ->mutable_gpu_options()
        ->set_per_process_gpu_memory_fraction(
            server_options.per_process_gpu_memory_fraction);

    if (server_options.tensorflow_intra_op_parallelism > 0 &&
        server_options.tensorflow_inter_op_parallelism > 0 &&
        server_options.tensorflow_session_parallelism > 0) {
      return errors::InvalidArgument(
          "Either configure server_options.tensorflow_session_parallelism "
          "or (server_options.tensorflow_intra_op_parallelism, "
          "server_options.tensorflow_inter_op_parallelism) separately. "
          "You cannot configure all.");
    } else if (server_options.tensorflow_intra_op_parallelism > 0 ||
               server_options.tensorflow_inter_op_parallelism > 0) {
      session_bundle_config.mutable_session_config()
          ->set_intra_op_parallelism_threads(
              server_options.tensorflow_intra_op_parallelism);
      session_bundle_config.mutable_session_config()
          ->set_inter_op_parallelism_threads(
              server_options.tensorflow_inter_op_parallelism);
    } else {
      session_bundle_config.mutable_session_config()
          ->set_intra_op_parallelism_threads(
              server_options.tensorflow_session_parallelism);
      session_bundle_config.mutable_session_config()
          ->set_inter_op_parallelism_threads(
              server_options.tensorflow_session_parallelism);
    }
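
    // For illustration: the two valid ways to set session parallelism from
    // the command line (flag values hypothetical). Either set the two
    // op-level knobs,
    //
    //   --tensorflow_intra_op_parallelism=4 --tensorflow_inter_op_parallelism=4
    //
    // or set the single combined knob, which applies one value to both pools:
    //
    //   --tensorflow_session_parallelism=8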

    const std::vector<string> tags =
        tensorflow::str_util::Split(server_options.saved_model_tags, ",");
    for (const string& tag : tags) {
      *session_bundle_config.add_saved_model_tags() = tag;
    }
    session_bundle_config.set_enable_model_warmup(
        server_options.enable_model_warmup);
    if (server_options.num_request_iterations_for_warmup > 0) {
      session_bundle_config.mutable_model_warmup_options()
          ->mutable_num_request_iterations()
          ->set_value(server_options.num_request_iterations_for_warmup);
    }
    session_bundle_config.set_remove_unused_fields_from_bundle_metagraph(
        server_options.remove_unused_fields_from_bundle_metagraph);
    session_bundle_config.set_prefer_tflite_model(
        server_options.prefer_tflite_model);
    session_bundle_config.set_num_tflite_interpreters_per_pool(
        server_options.num_tflite_interpreters_per_pool);
    session_bundle_config.set_num_tflite_pools(server_options.num_tflite_pools);
    session_bundle_config.set_mixed_precision(server_options.mixed_precision);

    TF_RETURN_IF_ERROR(tf_serving_registry->GetSetupPlatformConfigMap()(
        session_bundle_config, options.platform_config_map));
  } else {
    TF_RETURN_IF_ERROR(ParseProtoTextFile<PlatformConfigMap>(
        server_options.platform_config_file, &options.platform_config_map));
    TF_RETURN_IF_ERROR(tf_serving_registry->GetUpdatePlatformConfigMap()(
        options.platform_config_map));
  }

  options.custom_model_config_loader = &LoadCustomModelConfig;
  options.aspired_version_policy =
      std::unique_ptr<AspiredVersionPolicy>(new AvailabilityPreservingPolicy);
  options.num_load_threads = server_options.num_load_threads;
  options.num_unload_threads = server_options.num_unload_threads;
  options.max_num_load_retries = server_options.max_num_load_retries;
  options.load_retry_interval_micros =
      server_options.load_retry_interval_micros;
  options.file_system_poll_wait_seconds =
      server_options.file_system_poll_wait_seconds;
  options.flush_filesystem_caches = server_options.flush_filesystem_caches;
  options.allow_version_labels_for_unavailable_models =
      server_options.allow_version_labels_for_unavailable_models;
  options.force_allow_any_version_labels_for_unavailable_models =
      server_options.force_allow_any_version_labels_for_unavailable_models;
  options.enable_cors_support = server_options.enable_cors_support;

  TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_));

  // Model config polling thread must be started after the call to
  // ServerCore::Create() to prevent config reload being done concurrently from
  // Create() and the poll thread.
  if (server_options.fs_model_config_poll_wait_seconds > 0 &&
      !server_options.model_config_file.empty()) {
    PeriodicFunction::Options pf_options;
    pf_options.thread_name_prefix = "Server_fs_model_config_poll_thread";

    const string model_config_file = server_options.model_config_file;
    fs_config_polling_thread_.reset(new PeriodicFunction(
        [this, model_config_file] {
          this->PollFilesystemAndReloadConfig(model_config_file);
        },
        server_options.fs_model_config_poll_wait_seconds *
            tensorflow::EnvTime::kSecondsToMicros,
        pf_options));
  }
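
  // For illustration: periodic reload is enabled by pairing a config file
  // with a poll interval, e.g. (values hypothetical):
  //
  //   --model_config_file=/path/to/models.config \
  //   --model_config_file_poll_wait_seconds=60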

  // "0.0.0.0" makes gRPC listen on all network interfaces.
  const string server_address =
      "0.0.0.0:" + std::to_string(server_options.grpc_port);
  model_service_ = absl::make_unique<ModelServiceImpl>(server_core_.get());

  PredictionServiceOptions predict_server_options;
  predict_server_options.server_core = server_core_.get();
  predict_server_options.enforce_session_run_timeout =
      server_options.enforce_session_run_timeout;
  if (!server_options.thread_pool_factory_config_file.empty()) {
    ThreadPoolFactoryConfig thread_pool_factory_config;
    TF_RETURN_IF_ERROR(ParseProtoTextFile<ThreadPoolFactoryConfig>(
        server_options.thread_pool_factory_config_file,
        &thread_pool_factory_config));
    TF_RETURN_IF_ERROR(ThreadPoolFactoryRegistry::CreateFromAny(
        thread_pool_factory_config.thread_pool_factory_config(),
        &thread_pool_factory_));
  }
  predict_server_options.thread_pool_factory = thread_pool_factory_.get();
  prediction_service_ =
      tf_serving_registry->GetCreatePredictionService()(predict_server_options);

  ::grpc::ServerBuilder builder;
  // If defined, listen to a tcp port for gRPC/HTTP.
  if (server_options.grpc_port != 0) {
    builder.AddListeningPort(
        server_address,
        BuildServerCredentials(server_options.use_alts_credentials,
                               server_options.ssl_config_file));
  }
  // If defined, listen to a UNIX socket for gRPC.
  if (!server_options.grpc_socket_path.empty()) {
    const string grpc_socket_uri = "unix:" + server_options.grpc_socket_path;
    builder.AddListeningPort(
        grpc_socket_uri,
        BuildServerCredentials(server_options.use_alts_credentials,
                               server_options.ssl_config_file));
  }
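
  // For illustration: passing --grpc_socket_path=/tmp/tf_serving.sock
  // (hypothetical path) yields the URI "unix:/tmp/tf_serving.sock" in the
  // AddListeningPort call above.
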
  builder.RegisterService(model_service_.get());
  builder.RegisterService(prediction_service_.get());
  if (server_options.enable_profiler) {
    profiler_service_ = tensorflow::profiler::CreateProfilerService();
    builder.RegisterService(profiler_service_.get());
    LOG(INFO) << "Profiler service is enabled";
  }
  builder.SetMaxMessageSize(tensorflow::kint32max);
  const std::vector<GrpcChannelArgument> channel_arguments =
      parseGrpcChannelArgs(server_options.grpc_channel_arguments);
  for (const GrpcChannelArgument& channel_argument : channel_arguments) {
    // gRPC accepts arguments of two types, int and string. We attempt to
    // parse each arg as an int and pass it on as such if successful.
    // Otherwise we pass it as a string. gRPC will log arguments that were
    // not accepted.
    tensorflow::int32 value;
    if (tensorflow::strings::safe_strto32(channel_argument.value, &value)) {
      builder.AddChannelArgument(channel_argument.key, value);
    } else {
      builder.AddChannelArgument(channel_argument.key, channel_argument.value);
    }
  }

  ::grpc::ResourceQuota res_quota;
  res_quota.SetMaxThreads(server_options.grpc_max_threads);
  builder.SetResourceQuota(res_quota);
  ::grpc::EnableDefaultHealthCheckService(
      server_options.enable_grpc_healthcheck_service);
  grpc_server_ = builder.BuildAndStart();
  if (grpc_server_ == nullptr) {
    return errors::InvalidArgument("Failed to BuildAndStart gRPC server");
  }

  // Mark the registered services as serving only after BuildAndStart has
  // succeeded, so we never dereference a null grpc_server_.
  if (server_options.enable_grpc_healthcheck_service) {
    grpc_server_->GetHealthCheckService()->SetServingStatus("ModelService",
                                                            true);
    grpc_server_->GetHealthCheckService()->SetServingStatus(
        "PredictionService", true);
  }

  if (server_options.grpc_port != 0) {
    LOG(INFO) << "Running gRPC ModelServer at " << server_address << " ...";
  }
  if (!server_options.grpc_socket_path.empty()) {
    LOG(INFO) << "Running gRPC ModelServer at UNIX socket "
              << server_options.grpc_socket_path << " ...";
  }

  if (server_options.http_port != 0) {
    if (server_options.http_port != server_options.grpc_port) {
      const string server_address =
          "localhost:" + std::to_string(server_options.http_port);
      MonitoringConfig monitoring_config;
      if (!server_options.monitoring_config_file.empty()) {
        TF_RETURN_IF_ERROR(ParseProtoTextFile<MonitoringConfig>(
            server_options.monitoring_config_file, &monitoring_config));
      }
      http_server_ = CreateAndStartHttpServer(
          server_options.http_port, server_options.http_num_threads,
          server_options.http_timeout_in_ms, monitoring_config,
          server_core_.get());
      if (http_server_ != nullptr) {
        LOG(INFO) << "Exporting HTTP/REST API at: " << server_address
                  << " ...";
      } else {
        LOG(ERROR) << "Failed to start HTTP Server at " << server_address;
      }
    } else {
      LOG(ERROR) << "server_options.http_port cannot be the same as "
                 << "grpc_port. Please use a different port for the "
                 << "HTTP/REST API. Skipped exporting HTTP/REST API.";
    }
  }
  return absl::OkStatus();
}
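
// For illustration: once BuildAndStart has exported the HTTP/REST API, a
// prediction can be requested with e.g. (model name, input, and port
// hypothetical; 8501 is the conventional REST port):
//
//   curl -d '{"instances": [1.0, 2.0]}' \
//     http://localhost:8501/v1/models/half_plus_two:predict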

void Server::WaitForTermination() {
  if (http_server_ != nullptr) {
    http_server_->WaitForTermination();
  }
  if (grpc_server_ != nullptr) {
    grpc_server_->Wait();
  }
}

}  // namespace main
}  // namespace serving
}  // namespace tensorflow
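
A minimal sketch of driving this class, assuming the flag parsing done by the real model server binary (main.cc) is replaced by hard-coded values; the port, model name, and base path below are hypothetical:

#include "tensorflow_serving/model_servers/server.h"

int main(int argc, char** argv) {
  tensorflow::serving::main::Server server;
  tensorflow::serving::main::Server::Options options;
  options.grpc_port = 8500;
  options.model_name = "half_plus_two";
  options.model_base_path = "/models/half_plus_two";

  const tensorflow::Status status = server.BuildAndStart(options);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to start model server: " << status;
    return -1;
  }
  // Blocks until both the gRPC and HTTP servers shut down.
  server.WaitForTermination();
  return 0;
}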