16 #include "tensorflow_serving/resources/resource_util.h"
25 #include "google/protobuf/wrappers.pb.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/types.h"
31 #include "tsl/platform/errors.h"
33 namespace tensorflow {
39 bool RawResourcesEqual(
const Resource& lhs,
const Resource& rhs) {
40 if (lhs.device() != rhs.device()) {
44 if (lhs.has_device_instance() != rhs.has_device_instance()) {
47 if (lhs.has_device_instance()) {
48 if (lhs.device_instance().value() != rhs.device_instance().value()) {
53 return lhs.kind() == rhs.kind();
57 std::map<string, uint32> StripDevicesWithZeroInstances(
58 const std::map<string, uint32>& devices) {
59 std::map<string, uint32> result;
60 for (
const auto& entry : devices) {
61 if (entry.second > 0) {
70 ResourceAllocation::Entry* FindMutableEntry(
const Resource& resource,
71 ResourceAllocation* allocation) {
72 for (ResourceAllocation::Entry& entry :
73 *allocation->mutable_resource_quantities()) {
74 if (RawResourcesEqual(entry.resource(), resource)) {
83 ResourceAllocation::Entry* FindOrInsertMutableEntry(
84 const Resource& resource, ResourceAllocation* allocation) {
85 ResourceAllocation::Entry* entry = FindMutableEntry(resource, allocation);
86 if (entry ==
nullptr) {
87 entry = allocation->add_resource_quantities();
88 *entry->mutable_resource() = resource;
89 entry->set_quantity(0);
96 ResourceUtil::ResourceUtil(
const Options& options)
97 : devices_(StripDevicesWithZeroInstances(options.devices)) {}
99 Status ResourceUtil::VerifyValidity(
100 const ResourceAllocation& allocation)
const {
101 const Status result = [
this, &allocation]() -> Status {
103 ResourceAllocation validated_entries;
104 for (
const auto& entry : allocation.resource_quantities()) {
105 TF_RETURN_IF_ERROR(VerifyFunctionInternal(
106 [&]() {
return VerifyResourceValidity(entry.resource()); },
107 DCHECKFailOption::kDoNotDCHECKFail));
109 if (FindMutableEntry(entry.resource(), &validated_entries) !=
nullptr) {
110 return errors::InvalidArgument(
111 "Invalid resource allocation: Repeated resource\n",
112 entry.resource().DebugString(),
"in allocation\n",
113 allocation.DebugString());
116 *validated_entries.add_resource_quantities() = entry;
121 LOG(ERROR) << result;
127 Status ResourceUtil::VerifyResourceValidity(
const Resource& resource)
const {
128 const Status result = [
this, &resource]() -> Status {
129 auto it = devices_.find(resource.device());
130 if (it == devices_.end()) {
131 return errors::InvalidArgument(
132 "Invalid resource allocation: Invalid device ", resource.device());
134 const uint32 num_instances = it->second;
135 if (resource.has_device_instance() &&
136 resource.device_instance().value() >= num_instances) {
137 return errors::InvalidArgument(
138 "Invalid resource allocation: Invalid device instance ",
139 resource.device(),
":", resource.device_instance().value());
144 LOG(ERROR) << result;
150 Status ResourceUtil::VerifyOverrideDeviceValidity(
151 const ResourceAllocation& base_allocation,
152 const ResourceAllocation& override_allocation)
const {
153 absl::flat_hash_set<std::pair<std::string, std::string>>
154 base_device_kind_pairs;
155 for (
const auto& entry : base_allocation.resource_quantities()) {
156 base_device_kind_pairs.insert(
157 {entry.resource().device(), entry.resource().kind()});
159 for (
const auto& entry : override_allocation.resource_quantities()) {
160 if (base_device_kind_pairs.find(
161 {entry.resource().device(), entry.resource().kind()}) ==
162 base_device_kind_pairs.end()) {
163 return errors::InvalidArgument(
164 "Invalid resource allocation: device-kind from override "
165 "resource was not found in base resource: ",
166 entry.resource().DebugString());
172 ResourceAllocation ResourceUtil::Normalize(
173 const ResourceAllocation& allocation)
const {
174 return NormalizeResourceAllocation(allocation);
177 bool ResourceUtil::IsNormalized(
const ResourceAllocation& allocation)
const {
178 return IsResourceAllocationNormalized(allocation);
181 bool ResourceUtil::IsBound(
const ResourceAllocation& allocation)
const {
182 return IsBoundNormalized(Normalize(allocation));
185 Resource ResourceUtil::CreateBoundResource(
const string& device,
187 uint32 device_instance)
const {
188 DCHECK(devices_.find(device) != devices_.end());
190 resource.set_device(device);
191 resource.set_kind(kind);
192 resource.mutable_device_instance()->set_value(device_instance);
196 uint64_t ResourceUtil::GetQuantity(
const Resource& resource,
197 const ResourceAllocation& allocation)
const {
198 DCHECK(devices_.find(resource.device()) != devices_.end());
199 for (
const ResourceAllocation::Entry& entry :
200 allocation.resource_quantities()) {
201 if (ResourcesEqual(entry.resource(), resource)) {
202 return entry.quantity();
208 void ResourceUtil::SetQuantity(
const Resource& resource, uint64_t quantity,
209 ResourceAllocation* allocation)
const {
210 DCHECK(devices_.find(resource.device()) != devices_.end());
211 for (
int i = 0; i < allocation->resource_quantities().size(); ++i) {
212 ResourceAllocation::Entry* entry =
213 allocation->mutable_resource_quantities(i);
214 if (ResourcesEqual(entry->resource(), resource)) {
215 entry->set_quantity(quantity);
219 ResourceAllocation::Entry* new_entry = allocation->add_resource_quantities();
220 *new_entry->mutable_resource() = resource;
221 new_entry->set_quantity(quantity);
224 void ResourceUtil::Add(
const ResourceAllocation& to_add,
225 ResourceAllocation* base)
const {
226 *base = Normalize(*base);
227 return AddNormalized(Normalize(to_add), base);
230 bool ResourceUtil::Subtract(
const ResourceAllocation& to_subtract,
231 ResourceAllocation* base)
const {
232 *base = Normalize(*base);
233 return SubtractNormalized(Normalize(to_subtract), base);
236 void ResourceUtil::Multiply(uint64_t multiplier,
237 ResourceAllocation* base)
const {
238 *base = Normalize(*base);
239 return MultiplyNormalized(multiplier, base);
242 bool ResourceUtil::Equal(
const ResourceAllocation& lhs,
243 const ResourceAllocation& rhs)
const {
244 return EqualNormalized(Normalize(lhs), Normalize(rhs));
247 bool ResourceUtil::ResourcesEqual(
const Resource& lhs,
248 const Resource& rhs)
const {
249 return ResourcesEqualNormalized(NormalizeResource(lhs),
250 NormalizeResource(rhs));
253 bool ResourceUtil::LessThanOrEqual(
const ResourceAllocation& lhs,
254 const ResourceAllocation& rhs)
const {
255 return LessThanOrEqualNormalized(Normalize(lhs), Normalize(rhs));
258 ResourceAllocation ResourceUtil::Overbind(
259 const ResourceAllocation& allocation)
const {
260 return OverbindNormalized(Normalize(allocation));
263 ResourceAllocation ResourceUtil::Max(
const ResourceAllocation& lhs,
264 const ResourceAllocation& rhs)
const {
265 return MaxNormalized(Normalize(lhs), Normalize(rhs));
268 ResourceAllocation ResourceUtil::Min(
const ResourceAllocation& lhs,
269 const ResourceAllocation& rhs)
const {
270 return MinNormalized(Normalize(lhs), Normalize(rhs));
273 ResourceAllocation ResourceUtil::NormalizeResourceAllocation(
274 const ResourceAllocation& allocation)
const {
275 if (!VerifyFunctionInternal([&]() {
return VerifyValidity(allocation); },
276 DCHECKFailOption::kDoDCHECKFail)
281 ResourceAllocation normalized;
282 for (
const ResourceAllocation::Entry& entry :
283 allocation.resource_quantities()) {
284 if (entry.quantity() == 0) {
288 ResourceAllocation::Entry* normalized_entry =
289 normalized.add_resource_quantities();
290 *normalized_entry->mutable_resource() = NormalizeResource(entry.resource());
291 normalized_entry->set_quantity(entry.quantity());
296 bool ResourceUtil::IsResourceAllocationNormalized(
297 const ResourceAllocation& allocation)
const {
298 if (!VerifyFunctionInternal([&]() {
return VerifyValidity(allocation); },
299 DCHECKFailOption::kDoDCHECKFail)
304 for (
const auto& entry : allocation.resource_quantities()) {
305 if (entry.quantity() == 0) {
308 if (!IsResourceNormalized(entry.resource())) {
315 bool ResourceUtil::IsBoundNormalized(
316 const ResourceAllocation& allocation)
const {
317 DCHECK(IsNormalized(allocation));
318 for (
const auto& entry : allocation.resource_quantities()) {
319 if (!entry.resource().has_device_instance()) {
326 Status ResourceUtil::VerifyFunctionInternal(
327 std::function<Status()> fn, DCHECKFailOption dcheck_fail_option)
const {
328 const Status result = fn();
330 if (dcheck_fail_option == DCHECKFailOption::kDoDCHECKFail) {
331 TF_DCHECK_OK(result);
337 Resource ResourceUtil::NormalizeResource(
const Resource& resource)
const {
338 Resource normalized = resource;
339 if (!normalized.has_device_instance()) {
340 const uint32 num_instances = devices_.find(normalized.device())->second;
341 if (num_instances == 1) {
342 normalized.mutable_device_instance()->set_value(0);
348 bool ResourceUtil::IsResourceNormalized(
const Resource& resource)
const {
349 if (!VerifyFunctionInternal(
350 [&]() {
return VerifyResourceValidity(resource); },
351 DCHECKFailOption::kDoDCHECKFail)
358 return resource.has_device_instance() ||
359 devices_.find(resource.device())->second > 1;
362 void ResourceUtil::AddNormalized(
const ResourceAllocation& to_add,
363 ResourceAllocation* base)
const {
364 DCHECK(IsNormalized(to_add));
365 DCHECK(IsNormalized(*base));
366 for (
const ResourceAllocation::Entry& to_add_entry :
367 to_add.resource_quantities()) {
368 ResourceAllocation::Entry* base_entry =
369 FindOrInsertMutableEntry(to_add_entry.resource(), base);
370 base_entry->set_quantity(base_entry->quantity() + to_add_entry.quantity());
372 DCHECK(IsNormalized(*base));
375 bool ResourceUtil::SubtractNormalized(
const ResourceAllocation& to_subtract,
376 ResourceAllocation* base)
const {
377 DCHECK(IsNormalized(to_subtract));
378 DCHECK(IsNormalized(*base));
381 std::vector<std::pair<ResourceAllocation::Entry*, uint64_t>> new_quantities;
382 for (
const ResourceAllocation::Entry& to_subtract_entry :
383 to_subtract.resource_quantities()) {
384 ResourceAllocation::Entry* base_entry =
385 FindMutableEntry(to_subtract_entry.resource(), base);
386 if (base_entry ==
nullptr ||
387 base_entry->quantity() < to_subtract_entry.quantity()) {
390 const uint64_t new_quantity =
391 base_entry->quantity() - to_subtract_entry.quantity();
392 new_quantities.push_back({base_entry, new_quantity});
394 for (
const auto& new_quantity : new_quantities) {
395 ResourceAllocation::Entry* base_entry = new_quantity.first;
396 const uint64_t quantity = new_quantity.second;
397 base_entry->set_quantity(quantity);
399 *base = Normalize(*base);
403 void ResourceUtil::MultiplyNormalized(uint64_t multiplier,
404 ResourceAllocation* base)
const {
405 DCHECK(IsNormalized(*base));
406 for (
int i = 0; i < base->resource_quantities().size(); ++i) {
407 ResourceAllocation::Entry* entry = base->mutable_resource_quantities(i);
408 entry->set_quantity(entry->quantity() * multiplier);
412 bool ResourceUtil::EqualNormalized(
const ResourceAllocation& lhs,
413 const ResourceAllocation& rhs)
const {
414 if (!VerifyFunctionInternal([&]() {
return VerifyValidity(lhs); },
415 DCHECKFailOption::kDoDCHECKFail)
417 !VerifyFunctionInternal([&]() {
return VerifyValidity(rhs); },
418 DCHECKFailOption::kDoDCHECKFail)
422 DCHECK(IsNormalized(lhs));
423 DCHECK(IsNormalized(rhs));
425 if (lhs.resource_quantities().size() != rhs.resource_quantities().size()) {
429 for (
const ResourceAllocation::Entry& lhs_entry : lhs.resource_quantities()) {
430 bool matched =
false;
431 for (
const ResourceAllocation::Entry& rhs_entry :
432 rhs.resource_quantities()) {
433 if (ResourcesEqual(lhs_entry.resource(), rhs_entry.resource()) &&
434 lhs_entry.quantity() == rhs_entry.quantity()) {
447 bool ResourceUtil::ResourcesEqualNormalized(
const Resource& lhs,
448 const Resource& rhs)
const {
449 if (!VerifyFunctionInternal([&]() {
return VerifyResourceValidity(lhs); },
450 DCHECKFailOption::kDoDCHECKFail)
452 !VerifyFunctionInternal([&]() {
return VerifyResourceValidity(rhs); },
453 DCHECKFailOption::kDoDCHECKFail)
457 DCHECK(IsResourceNormalized(lhs));
458 DCHECK(IsResourceNormalized(rhs));
459 return RawResourcesEqual(lhs, rhs);
462 bool ResourceUtil::LessThanOrEqualNormalized(
463 const ResourceAllocation& lhs,
const ResourceAllocation& rhs)
const {
464 if (!VerifyFunctionInternal([&]() {
return VerifyValidity(lhs); },
465 DCHECKFailOption::kDoDCHECKFail)
467 !VerifyFunctionInternal([&]() {
return VerifyValidity(rhs); },
468 DCHECKFailOption::kDoDCHECKFail)
472 DCHECK(IsNormalized(lhs));
473 DCHECK(IsNormalized(rhs));
475 <<
"LessThanOrEqual() requires the second argument to be bound";
478 ResourceAllocation subtracted_rhs = rhs;
479 for (
const ResourceAllocation::Entry& lhs_entry : lhs.resource_quantities()) {
480 if (lhs_entry.resource().has_device_instance()) {
481 ResourceAllocation to_subtract;
482 *to_subtract.add_resource_quantities() = lhs_entry;
483 if (!Subtract(to_subtract, &subtracted_rhs)) {
491 for (
const ResourceAllocation::Entry& lhs_entry : lhs.resource_quantities()) {
492 if (!lhs_entry.resource().has_device_instance()) {
493 const uint32 num_instances =
494 devices_.find(lhs_entry.resource().device())->second;
495 Resource bound_resource = lhs_entry.resource();
496 bool found_room =
false;
497 for (
int instance = 0; instance < num_instances; ++instance) {
498 bound_resource.mutable_device_instance()->set_value(instance);
499 if (lhs_entry.quantity() <=
500 GetQuantity(bound_resource, subtracted_rhs)) {
513 ResourceAllocation ResourceUtil::OverbindNormalized(
514 const ResourceAllocation& allocation)
const {
515 if (!VerifyFunctionInternal([&]() {
return VerifyValidity(allocation); },
516 DCHECKFailOption::kDoDCHECKFail)
520 DCHECK(IsNormalized(allocation));
522 ResourceAllocation result;
523 for (
const ResourceAllocation::Entry& entry :
524 allocation.resource_quantities()) {
525 if (entry.resource().has_device_instance()) {
526 ResourceAllocation::Entry* result_entry =
527 FindOrInsertMutableEntry(entry.resource(), &result);
528 result_entry->set_quantity(entry.quantity() + result_entry->quantity());
532 const uint32 num_instances =
533 devices_.find(entry.resource().device())->second;
534 Resource bound_resource = entry.resource();
535 for (uint32 instance = 0; instance < num_instances; ++instance) {
536 bound_resource.mutable_device_instance()->set_value(instance);
537 ResourceAllocation::Entry* result_entry =
538 FindOrInsertMutableEntry(bound_resource, &result);
539 result_entry->set_quantity(entry.quantity() + result_entry->quantity());
542 DCHECK(IsNormalized(result));
546 ResourceAllocation ResourceUtil::MaxNormalized(
547 const ResourceAllocation& lhs,
const ResourceAllocation& rhs)
const {
548 DCHECK(IsNormalized(lhs));
549 DCHECK(IsNormalized(rhs));
551 ResourceAllocation max_resource_allocation = rhs;
552 for (
const ResourceAllocation::Entry& lhs_entry : lhs.resource_quantities()) {
553 ResourceAllocation::Entry* max_entry = FindOrInsertMutableEntry(
554 lhs_entry.resource(), &max_resource_allocation);
555 if (lhs_entry.quantity() >= max_entry->quantity()) {
556 max_entry->set_quantity(lhs_entry.quantity());
559 return max_resource_allocation;
562 ResourceAllocation ResourceUtil::MinNormalized(
563 const ResourceAllocation& lhs,
const ResourceAllocation& rhs)
const {
564 DCHECK(IsNormalized(lhs));
565 DCHECK(IsNormalized(rhs));
567 ResourceAllocation min_resource_allocation;
568 ResourceAllocation rhs_copy = rhs;
569 for (
const ResourceAllocation::Entry& lhs_entry : lhs.resource_quantities()) {
570 ResourceAllocation::Entry* rhs_entry =
571 FindMutableEntry(lhs_entry.resource(), &rhs_copy);
572 if (rhs_entry !=
nullptr) {
573 ResourceAllocation::Entry* min_entry =
574 min_resource_allocation.add_resource_quantities();
575 *min_entry->mutable_resource() = lhs_entry.resource();
576 min_entry->set_quantity(
577 std::min(lhs_entry.quantity(), rhs_entry->quantity()));
580 return min_resource_allocation;