16 #ifndef TENSORFLOW_SERVING_CORE_SOURCE_ROUTER_H_
17 #define TENSORFLOW_SERVING_CORE_SOURCE_ROUTER_H_
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow_serving/core/source.h"
25 #include "tensorflow_serving/core/source_adapter.h"
26 #include "tensorflow_serving/core/target.h"
28 namespace tensorflow {
// Sentinel that Route() may return in place of an output-port index.
// SetAspiredVersions() compares Route()'s result against this value before it
// performs any range check on the returned index.
54 static constexpr
int kNoRoute = -1;
// Returns one non-owning Source<T>* per output port, in port-index order.
// The ports are created the first time this is called; the number of ports
// created is taken from num_output_ports().
63 std::vector<Source<T>*> GetOutputPorts();
66 void SetAspiredVersions(
const StringPiece servable_name,
// Pure virtual: the number of output ports this router exposes. It is read by
// GetOutputPorts() when the ports are created; per the error message logged
// there, the value must be greater than 0.
75 virtual int num_output_ports()
const = 0;
82 virtual int Route(
const StringPiece servable_name,
// The output ports, owned by this object. GetOutputPorts() fills the vector
// with internal::IdentitySourceAdapter<T> instances, and SetAspiredVersions()
// indexes into it with the value returned by Route().
88 std::vector<std::unique_ptr<SourceAdapter<T, T>>> output_ports_;
// Notified by GetOutputPorts() once output_ports_ has been populated.
// SetAspiredVersions() waits on it before routing a request.
92 Notification output_ports_created_;
102 template <
typename T>
109 std::vector<ServableData<T>> Adapt(
110 const StringPiece servable_name,
116 template <
typename T>
118 const StringPiece servable_name, std::vector<
ServableData<T>> versions) {
124 template <
typename T>
125 SourceRouter<T>::~SourceRouter() {}
// Returns raw pointers to the router's output ports, creating the ports on the
// first call.
//
// Returns: one Source<T>* per element of output_ports_, in the same order. The
//   pointers are non-owning; the adapters remain owned by output_ports_.
127 template <
typename T>
128 std::vector<Source<T>*> SourceRouter<T>::GetOutputPorts() {
// Port creation happens only while the notification has not yet fired.
129 if (!output_ports_created_.HasBeenNotified()) {
130 int num_ports = num_output_ports();
// NOTE(review): the condition guarding this log is not visible in this
// excerpt; the message implies it fires when num_ports is not positive.
// Confirm what value num_ports takes afterwards.
132 LOG(ERROR) <<
"SourceRouter abstraction used improperly; "
133 "num_output_ports() must return a number greater than 0";
// Create one identity adapter per port.
137 for (
int i = 0; i < num_ports; ++i) {
138 output_ports_.emplace_back(
new internal::IdentitySourceAdapter<T>);
// Signal that the ports exist; SetAspiredVersions() waits on this.
140 output_ports_created_.Notify();
// Expose the ports as plain pointers without transferring ownership.
143 std::vector<Source<T>*> result;
144 for (
auto& output_port : output_ports_) {
145 result.push_back(output_port.get());
// Forwards an aspired-versions request to the output port selected by Route().
//
// servable_name: name passed through to Route() and to the chosen port.
// versions: passed to Route() for the routing decision, then moved into the
//   chosen port's SetAspiredVersions() call.
150 template <
typename T>
151 void SourceRouter<T>::SetAspiredVersions(
152 const StringPiece servable_name, std::vector<ServableData<T>> versions) {
// Block until GetOutputPorts() has created the ports.
153 output_ports_created_.WaitForNotification();
154 int output_port = Route(servable_name, versions);
// The kNoRoute sentinel is tested before the range check below.
// NOTE(review): the body of this branch is not visible in this excerpt;
// presumably the request is dropped here. Confirm.
155 if (output_port == kNoRoute) {
// Range check on the index returned by Route().
// NOTE(review): `output_ports_.size() - 1` is unsigned arithmetic. If
// output_ports_ can ever be empty, it wraps to a very large value and the
// upper-bound test never fires. `output_port >= output_ports_.size()` would
// avoid the wrap. Confirm whether an empty vector is reachable.
158 if (output_port < 0 || output_port > output_ports_.size() - 1) {
160 <<
"SourceRouter abstraction used improperly; Route() must return "
161 "kNoRoute or a value in [0, num_output_ports()-1]; suppressing the "
162 "aspired-versions request";
// Dispatch to the selected port. `versions` is moved from here, so it must not
// be read after this call.
166 output_ports_[output_port]->SetAspiredVersions(servable_name,
167 std::move(versions));