16 #include "tensorflow_serving/model_servers/server.h"
26 #include "google/protobuf/wrappers.pb.h"
27 #include "grpc/grpc.h"
28 #include "grpcpp/health_check_service_interface.h"
29 #include "grpcpp/resource_quota.h"
30 #include "grpcpp/security/server_credentials.h"
31 #include "grpcpp/server_builder.h"
32 #include "grpcpp/server_context.h"
33 #include "grpcpp/support/status.h"
34 #include "absl/memory/memory.h"
35 #include "tensorflow/c/c_api.h"
36 #include "tensorflow/cc/saved_model/tag_constants.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/strings/numbers.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/platform/env.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
44 #include "tensorflow/core/protobuf/config.pb.h"
45 #include "tsl/platform/errors.h"
46 #include "tensorflow_serving/config/model_server_config.pb.h"
47 #include "tensorflow_serving/config/monitoring_config.pb.h"
48 #include "tensorflow_serving/config/platform_config.pb.h"
49 #include "tensorflow_serving/config/ssl_config.pb.h"
50 #include "tensorflow_serving/core/availability_preserving_policy.h"
51 #include "tensorflow_serving/model_servers/grpc_status_util.h"
52 #include "tensorflow_serving/model_servers/model_platform_types.h"
53 #include "tensorflow_serving/model_servers/server_core.h"
54 #include "tensorflow_serving/model_servers/server_init.h"
55 #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
56 #include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
57 #include "tensorflow_serving/servables/tensorflow/util.h"
58 #include "tensorflow_serving/util/proto_util.h"
60 namespace tensorflow {
66 tensorflow::Status LoadCustomModelConfig(
67 const ::google::protobuf::Any& any,
68 EventBus<ServableState>* servable_event_bus,
69 UniquePtrWithDeps<AspiredVersionsManager>* manager) {
71 <<
"ModelServer does not yet support custom model config.";
74 ModelServerConfig BuildSingleModelConfig(
const string& model_name,
75 const string& model_base_path) {
76 ModelServerConfig config;
77 LOG(INFO) <<
"Building single TensorFlow model file config: "
78 <<
" model_name: " << model_name
79 <<
" model_base_path: " << model_base_path;
80 tensorflow::serving::ModelConfig* single_model =
81 config.mutable_model_config_list()->add_config();
82 single_model->set_name(model_name);
83 single_model->set_base_path(model_base_path);
84 single_model->set_model_platform(
85 tensorflow::serving::kTensorFlowModelPlatform);
91 struct GrpcChannelArgument {
98 std::vector<GrpcChannelArgument> parseGrpcChannelArgs(
99 const string& channel_arguments_str) {
100 const std::vector<string> channel_arguments =
101 tensorflow::str_util::Split(channel_arguments_str,
",");
102 std::vector<GrpcChannelArgument> result;
103 for (
const string& channel_argument : channel_arguments) {
104 const std::vector<string> key_val =
105 tensorflow::str_util::Split(channel_argument,
"=");
106 result.push_back({key_val[0], key_val[1]});
114 std::shared_ptr<::grpc::ServerCredentials> BuildServerCredentials(
115 bool use_alts_credentials,
const string& ssl_config_file) {
116 if (use_alts_credentials) {
117 LOG(INFO) <<
"Using ALTS credentials";
118 ::grpc::experimental::AltsServerCredentialsOptions alts_opts;
119 return ::grpc::experimental::AltsServerCredentials(alts_opts);
120 }
else if (ssl_config_file.empty()) {
121 LOG(INFO) <<
"Using InsecureServerCredentials";
122 return ::grpc::InsecureServerCredentials();
125 SSLConfig ssl_config;
126 TF_CHECK_OK(ParseProtoTextFile<SSLConfig>(ssl_config_file, &ssl_config));
127 LOG(INFO) <<
"Using SSL credentials";
129 ::grpc::SslServerCredentialsOptions ssl_ops(
130 ssl_config.client_verify()
131 ? GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
132 : GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE);
134 ssl_ops.force_client_auth = ssl_config.client_verify();
136 if (!ssl_config.custom_ca().empty()) {
137 ssl_ops.pem_root_certs = ssl_config.custom_ca();
140 ::grpc::SslServerCredentialsOptions::PemKeyCertPair keycert = {
141 ssl_config.server_key(), ssl_config.server_cert()};
143 ssl_ops.pem_key_cert_pairs.push_back(keycert);
145 return ::grpc::SslServerCredentials(ssl_ops);
// Defaults: serve a model named "default" and load SavedModels carrying
// the standard "serve" tag.
Server::Options::Options()
    : model_name("default"),
      saved_model_tags(tensorflow::kSavedModelTagServe) {}
158 fs_config_polling_thread_.reset();
159 WaitForTermination();
162 void Server::PollFilesystemAndReloadConfig(
const string& config_file_path) {
163 ModelServerConfig config;
164 const Status read_status =
165 ParseProtoTextFile<ModelServerConfig>(config_file_path, &config);
166 if (!read_status.ok()) {
167 LOG(ERROR) <<
"Failed to read ModelServerConfig file: "
168 << read_status.message();
172 const Status reload_status = server_core_->ReloadConfig(config);
173 if (!reload_status.ok()) {
174 LOG(ERROR) <<
"PollFilesystemAndReloadConfig failed to ReloadConfig: "
175 << reload_status.message();
179 Status Server::BuildAndStart(
const Options& server_options) {
180 if (server_options.grpc_port == 0 &&
181 server_options.grpc_socket_path.empty()) {
182 return errors::InvalidArgument(
183 "At least one of server_options.grpc_port or "
184 "server_options.grpc_socket_path must be set.");
187 if (server_options.use_alts_credentials &&
188 !server_options.ssl_config_file.empty()) {
189 return errors::InvalidArgument(
190 "Either use_alts_credentials must be false or "
191 "ssl_config_file must be empty.");
194 if (server_options.model_base_path.empty() &&
195 server_options.model_config_file.empty()) {
196 return errors::InvalidArgument(
197 "Both server_options.model_base_path and "
198 "server_options.model_config_file are empty!");
201 SetSignatureMethodNameCheckFeature(
202 server_options.enable_signature_method_name_check);
206 ServerCore::Options options;
209 if (server_options.model_config_file.empty()) {
210 options.model_server_config = BuildSingleModelConfig(
211 server_options.model_name, server_options.model_base_path);
213 TF_RETURN_IF_ERROR(ParseProtoTextFile<ModelServerConfig>(
214 server_options.model_config_file, &options.model_server_config));
217 auto* tf_serving_registry =
218 init::TensorflowServingFunctionRegistration::GetRegistry();
220 if (server_options.platform_config_file.empty()) {
221 SessionBundleConfig session_bundle_config;
223 if (server_options.enable_batching) {
224 BatchingParameters* batching_parameters =
225 session_bundle_config.mutable_batching_parameters();
226 if (server_options.batching_parameters_file.empty()) {
227 batching_parameters->mutable_thread_pool_name()->set_value(
228 "model_server_batch_threads");
230 TF_RETURN_IF_ERROR(ParseProtoTextFile<BatchingParameters>(
231 server_options.batching_parameters_file, batching_parameters));
233 if (server_options.enable_per_model_batching_params) {
234 session_bundle_config.set_enable_per_model_batching_params(
true);
236 }
else if (!server_options.batching_parameters_file.empty()) {
237 return errors::InvalidArgument(
238 "server_options.batching_parameters_file is set without setting "
239 "server_options.enable_batching to true.");
242 if (!server_options.tensorflow_session_config_file.empty()) {
244 ParseProtoTextFile(server_options.tensorflow_session_config_file,
245 session_bundle_config.mutable_session_config()));
248 session_bundle_config.mutable_session_config()
249 ->mutable_gpu_options()
250 ->set_per_process_gpu_memory_fraction(
251 server_options.per_process_gpu_memory_fraction);
253 if (server_options.tensorflow_intra_op_parallelism > 0 &&
254 server_options.tensorflow_inter_op_parallelism > 0 &&
255 server_options.tensorflow_session_parallelism > 0){
256 return errors::InvalidArgument(
"Either configure "
257 "server_options.tensorflow_session_parallelism "
258 "or (server_options.tensorflow_intra_op_parallelism, "
259 "server_options.tensorflow_inter_op_parallelism) separately. "
260 "You cannot configure all.");
261 }
else if (server_options.tensorflow_intra_op_parallelism > 0 ||
262 server_options.tensorflow_inter_op_parallelism > 0){
263 session_bundle_config.mutable_session_config()
264 ->set_intra_op_parallelism_threads(
265 server_options.tensorflow_intra_op_parallelism);
266 session_bundle_config.mutable_session_config()
267 ->set_inter_op_parallelism_threads(
268 server_options.tensorflow_inter_op_parallelism);
270 session_bundle_config.mutable_session_config()
271 ->set_intra_op_parallelism_threads(
272 server_options.tensorflow_session_parallelism);
273 session_bundle_config.mutable_session_config()
274 ->set_inter_op_parallelism_threads(
275 server_options.tensorflow_session_parallelism);
278 const std::vector<string> tags =
279 tensorflow::str_util::Split(server_options.saved_model_tags,
",");
280 for (
const string& tag : tags) {
281 *session_bundle_config.add_saved_model_tags() = tag;
283 session_bundle_config.set_enable_model_warmup(
284 server_options.enable_model_warmup);
285 if (server_options.num_request_iterations_for_warmup > 0) {
286 session_bundle_config.mutable_model_warmup_options()
287 ->mutable_num_request_iterations()
288 ->set_value(server_options.num_request_iterations_for_warmup);
290 session_bundle_config.set_remove_unused_fields_from_bundle_metagraph(
291 server_options.remove_unused_fields_from_bundle_metagraph);
292 session_bundle_config.set_prefer_tflite_model(
293 server_options.prefer_tflite_model);
294 session_bundle_config.set_num_tflite_interpreters_per_pool(
295 server_options.num_tflite_interpreters_per_pool);
296 session_bundle_config.set_num_tflite_pools(server_options.num_tflite_pools);
297 session_bundle_config.set_mixed_precision(server_options.mixed_precision);
299 TF_RETURN_IF_ERROR(tf_serving_registry->GetSetupPlatformConfigMap()(
300 session_bundle_config, options.platform_config_map));
302 TF_RETURN_IF_ERROR(ParseProtoTextFile<PlatformConfigMap>(
303 server_options.platform_config_file, &options.platform_config_map));
304 TF_RETURN_IF_ERROR(tf_serving_registry->GetUpdatePlatformConfigMap()(
305 options.platform_config_map));
308 options.custom_model_config_loader = &LoadCustomModelConfig;
309 options.aspired_version_policy =
310 std::unique_ptr<AspiredVersionPolicy>(
new AvailabilityPreservingPolicy);
311 options.num_load_threads = server_options.num_load_threads;
312 options.num_unload_threads = server_options.num_unload_threads;
313 options.max_num_load_retries = server_options.max_num_load_retries;
314 options.load_retry_interval_micros =
315 server_options.load_retry_interval_micros;
316 options.file_system_poll_wait_seconds =
317 server_options.file_system_poll_wait_seconds;
318 options.flush_filesystem_caches = server_options.flush_filesystem_caches;
319 options.allow_version_labels_for_unavailable_models =
320 server_options.allow_version_labels_for_unavailable_models;
321 options.force_allow_any_version_labels_for_unavailable_models =
322 server_options.force_allow_any_version_labels_for_unavailable_models;
323 options.enable_cors_support = server_options.enable_cors_support;
325 TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_));
330 if (server_options.fs_model_config_poll_wait_seconds > 0 &&
331 !server_options.model_config_file.empty()) {
332 PeriodicFunction::Options pf_options;
333 pf_options.thread_name_prefix =
"Server_fs_model_config_poll_thread";
335 const string model_config_file = server_options.model_config_file;
336 fs_config_polling_thread_.reset(
new PeriodicFunction(
337 [
this, model_config_file] {
338 this->PollFilesystemAndReloadConfig(model_config_file);
340 server_options.fs_model_config_poll_wait_seconds *
341 tensorflow::EnvTime::kSecondsToMicros,
346 const string server_address =
347 "0.0.0.0:" + std::to_string(server_options.grpc_port);
348 model_service_ = absl::make_unique<ModelServiceImpl>(server_core_.get());
350 PredictionServiceOptions predict_server_options;
351 predict_server_options.server_core = server_core_.get();
352 predict_server_options.enforce_session_run_timeout =
353 server_options.enforce_session_run_timeout;
354 if (!server_options.thread_pool_factory_config_file.empty()) {
355 ThreadPoolFactoryConfig thread_pool_factory_config;
356 TF_RETURN_IF_ERROR(ParseProtoTextFile<ThreadPoolFactoryConfig>(
357 server_options.thread_pool_factory_config_file,
358 &thread_pool_factory_config));
359 TF_RETURN_IF_ERROR(ThreadPoolFactoryRegistry::CreateFromAny(
360 thread_pool_factory_config.thread_pool_factory_config(),
361 &thread_pool_factory_));
363 predict_server_options.thread_pool_factory = thread_pool_factory_.get();
364 prediction_service_ =
365 tf_serving_registry->GetCreatePredictionService()(predict_server_options);
367 ::grpc::ServerBuilder builder;
369 if (server_options.grpc_port != 0) {
370 builder.AddListeningPort(
372 BuildServerCredentials(server_options.use_alts_credentials,
373 server_options.ssl_config_file));
376 if (!server_options.grpc_socket_path.empty()) {
377 const string grpc_socket_uri =
"unix:" + server_options.grpc_socket_path;
378 builder.AddListeningPort(
380 BuildServerCredentials(server_options.use_alts_credentials,
381 server_options.ssl_config_file));
383 builder.RegisterService(model_service_.get());
384 builder.RegisterService(prediction_service_.get());
385 if (server_options.enable_profiler) {
386 profiler_service_ = tensorflow::profiler::CreateProfilerService();
387 builder.RegisterService(profiler_service_.get());
388 LOG(INFO) <<
"Profiler service is enabled";
390 builder.SetMaxMessageSize(tensorflow::kint32max);
391 const std::vector<GrpcChannelArgument> channel_arguments =
392 parseGrpcChannelArgs(server_options.grpc_channel_arguments);
393 for (
const GrpcChannelArgument& channel_argument : channel_arguments) {
397 tensorflow::int32 value;
398 if (tensorflow::strings::safe_strto32(channel_argument.value, &value)) {
399 builder.AddChannelArgument(channel_argument.key, value);
401 builder.AddChannelArgument(channel_argument.key, channel_argument.value);
405 ::grpc::ResourceQuota res_quota;
406 res_quota.SetMaxThreads(server_options.grpc_max_threads);
407 builder.SetResourceQuota(res_quota);
408 ::grpc::EnableDefaultHealthCheckService(
409 server_options.enable_grpc_healthcheck_service);
410 grpc_server_ = builder.BuildAndStart();
412 if (server_options.enable_grpc_healthcheck_service) {
413 grpc_server_->GetHealthCheckService()->SetServingStatus(
"ModelService",
415 grpc_server_->GetHealthCheckService()->SetServingStatus(
"PredictionService",
419 if (grpc_server_ ==
nullptr) {
420 return errors::InvalidArgument(
"Failed to BuildAndStart gRPC server");
422 if (server_options.grpc_port != 0) {
423 LOG(INFO) <<
"Running gRPC ModelServer at " << server_address <<
" ...";
425 if (!server_options.grpc_socket_path.empty()) {
426 LOG(INFO) <<
"Running gRPC ModelServer at UNIX socket "
427 << server_options.grpc_socket_path <<
" ...";
430 if (server_options.http_port != 0) {
431 if (server_options.http_port != server_options.grpc_port) {
432 const string server_address =
433 "localhost:" + std::to_string(server_options.http_port);
434 MonitoringConfig monitoring_config;
435 if (!server_options.monitoring_config_file.empty()) {
436 TF_RETURN_IF_ERROR(ParseProtoTextFile<MonitoringConfig>(
437 server_options.monitoring_config_file, &monitoring_config));
439 http_server_ = CreateAndStartHttpServer(
440 server_options.http_port, server_options.http_num_threads,
441 server_options.http_timeout_in_ms, monitoring_config,
443 if (http_server_ !=
nullptr) {
444 LOG(INFO) <<
"Exporting HTTP/REST API at:" << server_address <<
" ...";
446 LOG(ERROR) <<
"Failed to start HTTP Server at " << server_address;
449 LOG(ERROR) <<
"server_options.http_port cannot be same as grpc_port. "
450 <<
"Please use a different port for HTTP/REST API. "
451 <<
"Skipped exporting HTTP/REST API.";
454 return absl::OkStatus();
457 void Server::WaitForTermination() {
458 if (http_server_ !=
nullptr) {
459 http_server_->WaitForTermination();
461 if (grpc_server_ !=
nullptr) {
462 grpc_server_->Wait();