16 #include "tensorflow_serving/model_servers/model_service_impl.h"
23 #include "absl/container/flat_hash_map.h"
24 #include "xla/tsl/lib/monitoring/collected_metrics.h"
25 #include "xla/tsl/lib/monitoring/collection_registry.h"
26 #include "tensorflow_serving/model_servers/get_model_status_impl.h"
27 #include "tensorflow_serving/model_servers/grpc_status_util.h"
28 #include "tensorflow_serving/util/status_util.h"
30 namespace tensorflow {
// GetModelStatus gRPC handler.
//
// Delegates to GetModelStatusImpl::GetModelStatus(core_, *request, response),
// which fills in `response`, and converts the resulting status to a
// ::grpc::Status via tensorflow::serving::ToGRPCStatus. `context` is not used
// in the visible code.
//
// NOTE(review): this excerpt omits the source lines around the VLOG and the
// end of the function (presumably an `if (!status.ok())` guard around the
// log and a `return status;`) — confirm against the full file.
33 ::grpc::Status ModelServiceImpl::GetModelStatus(
34 ::grpc::ServerContext *context,
const GetModelStatusRequest *request,
35 GetModelStatusResponse *response) {
36 const ::grpc::Status status = tensorflow::serving::ToGRPCStatus(
37 GetModelStatusImpl::GetModelStatus(core_, *request, response));
// Failure details are emitted only at verbose log level 1, not LOG(ERROR).
39 VLOG(1) <<
"GetModelStatus failed: " << status.error_message();
// ReloadConfig gRPC handler.
//
// Visible flow:
//   1. Copies request->config() into a local ModelServerConfig.
//   2. Captures `old_metric_values` before handling the config (its
//      initializer is not visible in this excerpt).
//   3. Switches on the config case. For kModelConfigList, every entry's
//      index, base path, name and platform is logged at INFO. The error
//      path sets `status` to errors::InvalidArgument, saying only
//      ModelConfigList is supported.
//   4. Logs status.message() at ERROR, captures `new_metric_values`, and
//      records the per-metric increase in `response` via
//      RecordMetricsIncrease.
//   5. Stores the final status in response->status() as a StatusProto and
//      also returns it as the RPC's ::grpc::Status.
//
// NOTE(review): this excerpt is missing several source lines. Missing:
//   - the declaration of `status`;
//   - the initializers of both metric snapshots (presumably
//     GetMetrics(request));
//   - the call that actually applies the new config (presumably
//     core_->ReloadConfig(server_config));
//   - the `break`/`default:` labels;
//   - any `if (!status.ok())` guard around the ERROR log.
// Confirm against the full file before relying on this description.
44 ::grpc::Status ModelServiceImpl::HandleReloadConfigRequest(
45 ::grpc::ServerContext *context,
const ReloadConfigRequest *request,
46 ReloadConfigResponse *response) {
47 ModelServerConfig server_config = request->config();
// Metric snapshot taken before the config is handled; compared against
// `new_metric_values` further down.
49 const absl::flat_hash_map<std::string, int64_t> old_metric_values =
51 switch (server_config.config_case()) {
52 case ModelServerConfig::kModelConfigList: {
// NOTE(review): `list` and `config` below are by-value copies of protobuf
// messages; binding `const ModelConfigList &` / `const ModelConfig &` would
// avoid the copies, since generated message accessors return const refs.
53 const ModelConfigList list = server_config.model_config_list();
55 for (
int index = 0; index < list.config_size(); index++) {
56 const ModelConfig config = list.config(index);
// One multi-line INFO record per config entry, for operator visibility.
57 LOG(INFO) <<
"\nConfig entry"
58 <<
"\n\tindex : " << index
59 <<
"\n\tpath : " << config.base_path()
60 <<
"\n\tname : " << config.name()
61 <<
"\n\tplatform : " << config.model_platform();
// Error path for unsupported config types (per the message text); the
// enclosing case label is not visible in this excerpt.
67 status = errors::InvalidArgument(
68 "ServerModelConfig type not supported by HandleReloadConfigRequest."
69 " Only ModelConfigList is currently supported");
73 LOG(ERROR) <<
"ReloadConfig failed: " << status.message();
// Second metric snapshot; the difference from the first is written into
// `response` by RecordMetricsIncrease.
75 const absl::flat_hash_map<std::string, int64_t> new_metric_values =
77 RecordMetricsIncrease(old_metric_values, new_metric_values, response);
// The status is reported twice: embedded in the response body as a
// StatusProto, and returned as the RPC status.
79 const StatusProto status_proto = ToStatusProto(status);
80 *response->mutable_status() = status_proto;
81 return ToGRPCStatus(status);
// Returns a snapshot of the metrics named in request->metric_names(), keyed
// by metric name.
//
// Calls CollectMetrics() with default CollectMetricsOptions on
// tsl::monitoring::CollectionRegistry::Default(). For each requested name it
// looks the name up in point_set_map and reads the int64_value of the first
// Point. A name with no collected point set, or with an empty point list,
// maps to 0.
//
// NOTE(review):
//   - Only point [0] is read. For a metric with several points (e.g. one per
//     label combination — assumption, confirm) the others are ignored.
//   - The value is read as int64_value without checking the point's value
//     type.
//   - Duplicate names in metric_names() keep the first insertion, because
//     flat_hash_map::insert does not overwrite; all duplicates would carry
//     the same value anyway.
//   - The initializer of `points` is not visible in this excerpt (presumably
//     the address of the found point set's `points` vector).
84 absl::flat_hash_map<std::string, int64_t> ModelServiceImpl::GetMetrics(
85 const ReloadConfigRequest *request) {
86 absl::flat_hash_map<std::string, int64_t> metric_values = {};
87 const tsl::monitoring::CollectionRegistry::CollectMetricsOptions options;
88 tsl::monitoring::CollectionRegistry *collection_registry =
89 tsl::monitoring::CollectionRegistry::Default();
90 std::unique_ptr<tsl::monitoring::CollectedMetrics> collected_metrics =
91 collection_registry->CollectMetrics(options);
93 for (
const std::string &metric_name : request->metric_names()) {
// Defaults to 0 when the metric is absent or has no points.
94 int64_t metric_value = 0;
95 auto it = collected_metrics->point_set_map.find(metric_name);
96 if (it != collected_metrics->point_set_map.end()) {
97 std::vector<std::unique_ptr<tsl::monitoring::Point>> *points =
99 if (!points->empty()) {
100 metric_value = (*points)[0]->int64_value;
103 metric_values.insert({metric_name, metric_value});
105 return metric_values;
// Appends one metric entry to `response` (via add_metric()) for each entry
// in new_metric_values. Each entry's name is the metric name, and its
// int64_value_increase is (new value - old value).
//
// Names absent from old_metric_values use the fallback old value from the
// ternary's else-branch. That branch is not visible in this excerpt
// (presumably 0 — confirm). The declaration of `metric` is also not visible.
//
// NOTE(review):
//   - new_metric_values is an absl::flat_hash_map, whose iteration order is
//     unspecified. The order of entries appended to the response should
//     therefore not be relied upon.
//   - contains() followed by at() performs two hash lookups where a single
//     find() would do.
108 void ModelServiceImpl::RecordMetricsIncrease(
109 const absl::flat_hash_map<std::string, int64_t> &old_metric_values,
110 const absl::flat_hash_map<std::string, int64_t> &new_metric_values,
111 ReloadConfigResponse *response) {
112 for (
const auto &[metric_name, metric_value] : new_metric_values) {
114 metric.set_name(metric_name);
// Fallback for metrics missing from the earlier snapshot is on the
// (not visible) else-branch of this ternary.
115 int64_t old_metric_value = old_metric_values.contains(metric_name)
116 ? old_metric_values.at(metric_name)
118 metric.set_int64_value_increase(metric_value - old_metric_value);
119 *response->add_metric() = metric;
// NOTE(review): stray member declaration with no enclosing class and no
// terminating `;` in this excerpt. It was presumably copied from a header
// (the class that owns config_mu_); confirm where it belongs.
// Presumably applies `config` to the server.
// TF_LOCKS_EXCLUDED(config_mu_) is a thread-safety annotation: callers must
// not already hold config_mu_ when calling this.
virtual Status ReloadConfig(const ModelServerConfig &config) TF_LOCKS_EXCLUDED(config_mu_)