TensorFlow Serving C++ API Documentation
model_service_impl.cc
1 /* Copyright 2018 Google Inc. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7  http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_serving/model_servers/model_service_impl.h"
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "xla/tsl/lib/monitoring/collected_metrics.h"
25 #include "xla/tsl/lib/monitoring/collection_registry.h"
26 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
27 #include "tensorflow_serving/model_servers/grpc_status_util.h"
28 #include "tensorflow_serving/util/status_util.h"
29 
30 namespace tensorflow {
31 namespace serving {
32 
33 ::grpc::Status ModelServiceImpl::GetModelStatus(
34  ::grpc::ServerContext *context, const GetModelStatusRequest *request,
35  GetModelStatusResponse *response) {
36  const ::grpc::Status status = tensorflow::serving::ToGRPCStatus(
37  GetModelStatusImpl::GetModelStatus(core_, *request, response));
38  if (!status.ok()) {
39  VLOG(1) << "GetModelStatus failed: " << status.error_message();
40  }
41  return status;
42 }
43 
44 ::grpc::Status ModelServiceImpl::HandleReloadConfigRequest(
45  ::grpc::ServerContext *context, const ReloadConfigRequest *request,
46  ReloadConfigResponse *response) {
47  ModelServerConfig server_config = request->config();
48  Status status;
49  const absl::flat_hash_map<std::string, int64_t> old_metric_values =
50  GetMetrics(request);
51  switch (server_config.config_case()) {
52  case ModelServerConfig::kModelConfigList: {
53  const ModelConfigList list = server_config.model_config_list();
54 
55  for (int index = 0; index < list.config_size(); index++) {
56  const ModelConfig config = list.config(index);
57  LOG(INFO) << "\nConfig entry"
58  << "\n\tindex : " << index
59  << "\n\tpath : " << config.base_path()
60  << "\n\tname : " << config.name()
61  << "\n\tplatform : " << config.model_platform();
62  }
63  status = core_->ReloadConfig(server_config);
64  break;
65  }
66  default:
67  status = errors::InvalidArgument(
68  "ServerModelConfig type not supported by HandleReloadConfigRequest."
69  " Only ModelConfigList is currently supported");
70  }
71 
72  if (!status.ok()) {
73  LOG(ERROR) << "ReloadConfig failed: " << status.message();
74  }
75  const absl::flat_hash_map<std::string, int64_t> new_metric_values =
76  GetMetrics(request);
77  RecordMetricsIncrease(old_metric_values, new_metric_values, response);
78 
79  const StatusProto status_proto = ToStatusProto(status);
80  *response->mutable_status() = status_proto;
81  return ToGRPCStatus(status);
82 }
83 
84 absl::flat_hash_map<std::string, int64_t> ModelServiceImpl::GetMetrics(
85  const ReloadConfigRequest *request) {
86  absl::flat_hash_map<std::string, int64_t> metric_values = {};
87  const tsl::monitoring::CollectionRegistry::CollectMetricsOptions options;
88  tsl::monitoring::CollectionRegistry *collection_registry =
89  tsl::monitoring::CollectionRegistry::Default();
90  std::unique_ptr<tsl::monitoring::CollectedMetrics> collected_metrics =
91  collection_registry->CollectMetrics(options);
92 
93  for (const std::string &metric_name : request->metric_names()) {
94  int64_t metric_value = 0;
95  auto it = collected_metrics->point_set_map.find(metric_name);
96  if (it != collected_metrics->point_set_map.end()) {
97  std::vector<std::unique_ptr<tsl::monitoring::Point>> *points =
98  &it->second->points;
99  if (!points->empty()) {
100  metric_value = (*points)[0]->int64_value;
101  }
102  }
103  metric_values.insert({metric_name, metric_value});
104  }
105  return metric_values;
106 }
107 
108 void ModelServiceImpl::RecordMetricsIncrease(
109  const absl::flat_hash_map<std::string, int64_t> &old_metric_values,
110  const absl::flat_hash_map<std::string, int64_t> &new_metric_values,
111  ReloadConfigResponse *response) {
112  for (const auto &[metric_name, metric_value] : new_metric_values) {
113  Metric metric;
114  metric.set_name(metric_name);
115  int64_t old_metric_value = old_metric_values.contains(metric_name)
116  ? old_metric_values.at(metric_name)
117  : 0;
118  metric.set_int64_value_increase(metric_value - old_metric_value);
119  *response->add_metric() = metric;
120  }
121 }
122 } // namespace serving
123 } // namespace tensorflow
virtual Status ReloadConfig(const ModelServerConfig &config) TF_LOCKS_EXCLUDED(config_mu_)
Definition: server_core.cc:447