16 #include "tensorflow_serving/core/test_util/session_test_util.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/strip.h"
24 #include "tensorflow/core/common_runtime/session_factory.h"
25 #include "tensorflow/core/public/session.h"
27 namespace tensorflow {
32 using NewSessionHook = std::function<Status(
const SessionOptions&)>;
33 NewSessionHook new_session_hook_;
35 NewSessionHook GetNewSessionHook() {
return new_session_hook_; }
// SessionFactory that claims sessions whose target begins with
// "new_session_hook/". For such targets it strips that prefix, lets the
// installed hook (if any) inspect the stripped options and fail the request,
// and then delegates to tensorflow::NewSession() with the remaining target.
43 class DelegatingSessionFactory :
public SessionFactory {
45 DelegatingSessionFactory() {}
// Accepts only targets that carry the hook prefix.
// NOTE(review): the prefix is written out as a literal here, whereas
// NewSession() strips kNewSessionHookSessionTargetPrefix (which has the same
// value, "new_session_hook/"). Using the constant in both places would keep
// them from drifting apart. Also, absl::StartsWith is declared in
// absl/strings/match.h, which is not among the visible includes — confirm it
// is included directly.
47 bool AcceptsOptions(
const SessionOptions& options)
override {
48 return absl::StartsWith(options.target,
"new_session_hook/");
// Creates the real session from a copy of `options` whose target has the
// hook prefix removed. If a hook is installed, it is called first with those
// stripped options, and a non-OK result is returned via TF_RETURN_IF_ERROR
// before any session is created. The session produced by
// tensorflow::NewSession() is then stored in *out_session.
51 Status NewSession(
const SessionOptions& options,
52 Session** out_session)
override {
// Copy, so the caller's options are left untouched.
53 auto actual_session_options = options;
54 actual_session_options.target = std::string(
55 absl::StripPrefix(options.target, kNewSessionHookSessionTargetPrefix));
// Take a snapshot of the hook; an empty hook means "no interception".
56 auto new_session_hook = GetNewSessionHook();
57 if (new_session_hook) {
58 TF_RETURN_IF_ERROR(new_session_hook(actual_session_options));
// Written by tensorflow::NewSession(); it is not initialized before that call.
60 Session* actual_session;
// NOTE(review): the trailing "));" closes a call whose opening is not visible
// in this excerpt (presumably a TF_RETURN_IF_ERROR( wrapper) — confirm.
62 tensorflow::NewSession(actual_session_options, &actual_session));
63 *out_session = actual_session;
// Helper whose constructor registers a DelegatingSessionFactory under the
// name "DELEGATING_SESSION".
// NOTE(review): the factory is allocated with raw `new` and is not deleted
// anywhere in the visible code. Presumably SessionFactory::Register keeps it
// for the lifetime of the process — confirm.
68 class DelegatingSessionRegistrar {
70 DelegatingSessionRegistrar() {
71 SessionFactory::Register(
"DELEGATING_SESSION",
72 new DelegatingSessionFactory());
75 static DelegatingSessionRegistrar registrar;
// Target prefix that routes a session through the delegating factory.
// NewSession() strips it before creating the real session.
const char kNewSessionHookSessionTargetPrefix[]{"new_session_hook/"};
// Installs `hook` as the callback that DelegatingSessionFactory::NewSession
// runs before creating each delegated session. It replaces any previously
// installed hook. Passing an empty std::function removes the hook, because
// NewSession only calls the hook when it converts to true.
81 void SetNewSessionHook(NewSessionHook hook) {
82 new_session_hook_ = std::move(hook);