OR-Tools  8.2
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2018 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <memory>
17#include <string>
18#include <utility>
19#include <vector>
20
21#include "absl/container/flat_hash_map.h"
22#include "absl/strings/str_cat.h"
23#include "absl/strings/str_format.h"
32#include "ortools/util/bitset.h"
35
36ABSL_FLAG(bool, cp_disable_expression_optimization, false,
37 "Disable special optimization when creating expressions.");
38ABSL_FLAG(bool, cp_share_int_consts, true,
39 "Share IntConst's with the same value.");
40
41#if defined(_MSC_VER)
42#pragma warning(disable : 4351 4355)
43#endif
44
45namespace operations_research {
46
47// ---------- IntExpr ----------
48
49IntVar* IntExpr::VarWithName(const std::string& name) {
50 IntVar* const var = Var();
51 var->set_name(name);
52 return var;
53}
54
55// ---------- IntVar ----------
56
57IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
58
59IntVar::IntVar(Solver* const s, const std::string& name)
60 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
62}
63
64// ----- Boolean variable -----
65
67
69 if (m <= 0) return;
70 if (m > 1) solver()->Fail();
71 SetValue(1);
72}
73
75 if (m >= 1) return;
76 if (m < 0) solver()->Fail();
77 SetValue(0);
78}
79
81 if (mi > 1 || ma < 0 || mi > ma) {
82 solver()->Fail();
83 }
84 if (mi == 1) {
85 SetValue(1);
86 } else if (ma == 0) {
87 SetValue(0);
88 }
89}
90
93 if (v == 0) {
94 SetValue(1);
95 } else if (v == 1) {
96 SetValue(0);
97 }
98 } else if (v == value_) {
99 solver()->Fail();
100 }
101}
102
104 if (u < l) return;
105 if (l <= 0 && u >= 1) {
106 solver()->Fail();
107 } else if (l == 1) {
108 SetValue(0);
109 } else if (u == 0) {
110 SetValue(1);
111 }
112}
113
117 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
118 } else {
119 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120 }
121 }
122}
123
125 return (1 + (value_ == kUnboundBooleanVarValue));
126}
127
129 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
130}
131
133 if (constant > 1 || constant < 0) {
134 return solver()->MakeIntConst(0);
135 }
136 if (constant == 1) {
137 return this;
138 } else { // constant == 0.
139 return solver()->MakeDifference(1, this)->Var();
140 }
141}
142
144 if (constant > 1 || constant < 0) {
145 return solver()->MakeIntConst(1);
146 }
147 if (constant == 1) {
148 return solver()->MakeDifference(1, this)->Var();
149 } else { // constant == 0.
150 return this;
151 }
152}
153
155 if (constant > 1) {
156 return solver()->MakeIntConst(0);
157 } else if (constant <= 0) {
158 return solver()->MakeIntConst(1);
159 } else {
160 return this;
161 }
162}
163
165 if (constant < 0) {
166 return solver()->MakeIntConst(0);
167 } else if (constant >= 1) {
168 return solver()->MakeIntConst(1);
169 } else {
170 return IsEqual(0);
171 }
172}
173
174std::string BooleanVar::DebugString() const {
175 std::string out;
176 const std::string& var_name = name();
177 if (!var_name.empty()) {
178 out = var_name + "(";
179 } else {
180 out = "BooleanVar(";
181 }
182 switch (value_) {
183 case 0:
184 out += "0";
185 break;
186 case 1:
187 out += "1";
188 break;
190 out += "0 .. 1";
191 break;
192 }
193 out += ")";
194 return out;
195}
196
197namespace {
198// ---------- Subclasses of IntVar ----------
199
200// ----- Domain Int Var: base class for variables -----
201// It Contains bounds and a bitset representation of possible values.
202class DomainIntVar : public IntVar {
203 public:
204 // Utility classes
205 class BitSetIterator : public BaseObject {
206 public:
207 BitSetIterator(uint64* const bitset, int64 omin)
208 : bitset_(bitset), omin_(omin), max_(kint64min), current_(kint64max) {}
209
210 ~BitSetIterator() override {}
211
212 void Init(int64 min, int64 max) {
213 max_ = max;
214 current_ = min;
215 }
216
217 bool Ok() const { return current_ <= max_; }
218
219 int64 Value() const { return current_; }
220
221 void Next() {
222 if (++current_ <= max_) {
224 bitset_, current_ - omin_, max_ - omin_) +
225 omin_;
226 }
227 }
228
229 std::string DebugString() const override { return "BitSetIterator"; }
230
231 private:
232 uint64* const bitset_;
233 const int64 omin_;
234 int64 max_;
236 };
237
238 class BitSet : public BaseObject {
239 public:
240 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
241 ~BitSet() override {}
242
243 virtual int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) = 0;
244 virtual int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) = 0;
245 virtual bool Contains(int64 val) const = 0;
246 virtual bool SetValue(int64 val) = 0;
247 virtual bool RemoveValue(int64 val) = 0;
248 virtual uint64 Size() const = 0;
249 virtual void DelayRemoveValue(int64 val) = 0;
250 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
251 virtual void ClearRemovedValues() = 0;
252 virtual std::string pretty_DebugString(int64 min, int64 max) const = 0;
253 virtual BitSetIterator* MakeIterator() = 0;
254
255 void InitHoles() {
256 const uint64 current_stamp = solver_->stamp();
257 if (holes_stamp_ < current_stamp) {
258 holes_.clear();
259 holes_stamp_ = current_stamp;
260 }
261 }
262
263 virtual void ClearHoles() { holes_.clear(); }
264
265 const std::vector<int64>& Holes() { return holes_; }
266
267 void AddHole(int64 value) { holes_.push_back(value); }
268
269 int NumHoles() const {
270 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
271 }
272
273 protected:
274 Solver* const solver_;
275
276 private:
277 std::vector<int64> holes_;
278 uint64 holes_stamp_;
279 };
280
281 class QueueHandler : public Demon {
282 public:
283 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
284 ~QueueHandler() override {}
285 void Run(Solver* const s) override {
286 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
287 var_->Process();
288 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
289 }
290 Solver::DemonPriority priority() const override {
292 }
293 std::string DebugString() const override {
294 return absl::StrFormat("Handler(%s)", var_->DebugString());
295 }
296
297 private:
298 DomainIntVar* const var_;
299 };
300
301 // Bounds and Value watchers
302
303 // This class stores the watchers variables attached to values. It is
304 // reversible and it helps maintaining the set of 'active' watchers
305 // (variables not bound to a single value).
306 template <class T>
307 class RevIntPtrMap {
308 public:
309 RevIntPtrMap(Solver* const solver, int64 rmin, int64 rmax)
310 : solver_(solver), range_min_(rmin), start_(0) {}
311
312 ~RevIntPtrMap() {}
313
314 bool Empty() const { return start_.Value() == elements_.size(); }
315
316 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
317
318 // Access with value API.
319
320 // Add the pointer to the map attached to the given value.
321 void UnsafeRevInsert(int64 value, T* elem) {
322 elements_.push_back(std::make_pair(value, elem));
323 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
324 solver_->AddBacktrackAction(
325 [this, value](Solver* s) { Uninsert(value); }, false);
326 }
327 }
328
329 T* FindPtrOrNull(int64 value, int* position) {
330 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
331 if (elements_[pos].first == value) {
332 if (position != nullptr) *position = pos;
333 return At(pos).second;
334 }
335 }
336 return nullptr;
337 }
338
339 // Access map through the underlying vector.
340 void RemoveAt(int position) {
341 const int start = start_.Value();
342 DCHECK_GE(position, start);
343 DCHECK_LT(position, elements_.size());
344 if (position > start) {
345 // Swap the current element with the one at the start position, and
346 // increase start.
347 const std::pair<int64, T*> copy = elements_[start];
348 elements_[start] = elements_[position];
349 elements_[position] = copy;
350 }
351 start_.Incr(solver_);
352 }
353
354 const std::pair<int64, T*>& At(int position) const {
355 DCHECK_GE(position, start_.Value());
356 DCHECK_LT(position, elements_.size());
357 return elements_[position];
358 }
359
360 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
361
362 int start() const { return start_.Value(); }
363 int end() const { return elements_.size(); }
364 // Number of active elements.
365 int Size() const { return elements_.size() - start_.Value(); }
366
367 // Removes the object permanently from the map.
368 void Uninsert(int64 value) {
369 for (int pos = 0; pos < elements_.size(); ++pos) {
370 if (elements_[pos].first == value) {
371 DCHECK_GE(pos, start_.Value());
372 const int last = elements_.size() - 1;
373 if (pos != last) { // Swap the current with the last.
374 elements_[pos] = elements_.back();
375 }
376 elements_.pop_back();
377 return;
378 }
379 }
380 LOG(FATAL) << "The element should have been removed";
381 }
382
383 private:
384 Solver* const solver_;
385 const int64 range_min_;
386 NumericalRev<int> start_;
387 std::vector<std::pair<int64, T*>> elements_;
388 };
389
390 // Base class for value watchers
391 class BaseValueWatcher : public Constraint {
392 public:
393 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
394
395 ~BaseValueWatcher() override {}
396
397 virtual IntVar* GetOrMakeValueWatcher(int64 value) = 0;
398
399 virtual void SetValueWatcher(IntVar* const boolvar, int64 value) = 0;
400 };
401
402 // This class monitors the domain of the variable and updates the
403 // IsEqual/IsDifferent boolean variables accordingly.
404 class ValueWatcher : public BaseValueWatcher {
405 public:
406 class WatchDemon : public Demon {
407 public:
408 WatchDemon(ValueWatcher* const watcher, int64 value, IntVar* var)
409 : value_watcher_(watcher), value_(value), var_(var) {}
410 ~WatchDemon() override {}
411
412 void Run(Solver* const solver) override {
413 value_watcher_->ProcessValueWatcher(value_, var_);
414 }
415
416 private:
417 ValueWatcher* const value_watcher_;
418 const int64 value_;
419 IntVar* const var_;
420 };
421
422 class VarDemon : public Demon {
423 public:
424 explicit VarDemon(ValueWatcher* const watcher)
425 : value_watcher_(watcher) {}
426
427 ~VarDemon() override {}
428
429 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
430
431 private:
432 ValueWatcher* const value_watcher_;
433 };
434
435 ValueWatcher(Solver* const solver, DomainIntVar* const variable)
436 : BaseValueWatcher(solver),
437 variable_(variable),
438 hole_iterator_(variable_->MakeHoleIterator(true)),
439 var_demon_(nullptr),
440 watchers_(solver, variable->Min(), variable->Max()) {}
441
442 ~ValueWatcher() override {}
443
444 IntVar* GetOrMakeValueWatcher(int64 value) override {
445 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
446 if (watcher != nullptr) return watcher;
447 if (variable_->Contains(value)) {
448 if (variable_->Bound()) {
449 return solver()->MakeIntConst(1);
450 } else {
451 const std::string vname = variable_->HasName()
452 ? variable_->name()
453 : variable_->DebugString();
454 const std::string bname =
455 absl::StrFormat("Watch<%s == %d>", vname, value);
456 IntVar* const boolvar = solver()->MakeBoolVar(bname);
457 watchers_.UnsafeRevInsert(value, boolvar);
458 if (posted_.Switched()) {
459 boolvar->WhenBound(
460 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
461 var_demon_->desinhibit(solver());
462 }
463 return boolvar;
464 }
465 } else {
466 return variable_->solver()->MakeIntConst(0);
467 }
468 }
469
470 void SetValueWatcher(IntVar* const boolvar, int64 value) override {
471 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
472 if (!boolvar->Bound()) {
473 watchers_.UnsafeRevInsert(value, boolvar);
474 if (posted_.Switched() && !boolvar->Bound()) {
475 boolvar->WhenBound(
476 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
477 var_demon_->desinhibit(solver());
478 }
479 }
480 }
481
482 void Post() override {
483 var_demon_ = solver()->RevAlloc(new VarDemon(this));
484 variable_->WhenDomain(var_demon_);
485 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
486 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
487 const int64 value = w.first;
488 IntVar* const boolvar = w.second;
489 if (!boolvar->Bound() && variable_->Contains(value)) {
490 boolvar->WhenBound(
491 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
492 }
493 }
494 posted_.Switch(solver());
495 }
496
497 void InitialPropagate() override {
498 if (variable_->Bound()) {
499 VariableBound();
500 } else {
501 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
502 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
503 const int64 value = w.first;
504 IntVar* const boolvar = w.second;
505 if (!variable_->Contains(value)) {
506 boolvar->SetValue(0);
507 watchers_.RemoveAt(pos);
508 } else {
509 if (boolvar->Bound()) {
510 ProcessValueWatcher(value, boolvar);
511 watchers_.RemoveAt(pos);
512 }
513 }
514 }
515 CheckInhibit();
516 }
517 }
518
519 void ProcessValueWatcher(int64 value, IntVar* boolvar) {
520 if (boolvar->Min() == 0) {
521 if (variable_->Size() < 0xFFFFFF) {
522 variable_->RemoveValue(value);
523 } else {
524 // Delay removal.
525 solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
526 }
527 } else {
528 variable_->SetValue(value);
529 }
530 }
531
532 void ProcessVar() {
533 const int kSmallList = 16;
534 if (variable_->Bound()) {
535 VariableBound();
536 } else if (watchers_.Size() <= kSmallList ||
537 variable_->Min() != variable_->OldMin() ||
538 variable_->Max() != variable_->OldMax()) {
539 // Brute force loop for small numbers of watchers, or if the bounds have
540 // changed, which would have required a sort (n log(n)) anyway to take
541 // advantage of.
542 ScanWatchers();
543 CheckInhibit();
544 } else {
545 // If there is no bitset, then there are no holes.
546 // In that case, the two loops above should have performed all
547 // propagation. Otherwise, scan the remaining watchers.
548 BitSet* const bitset = variable_->bitset();
549 if (bitset != nullptr && !watchers_.Empty()) {
550 if (bitset->NumHoles() * 2 < watchers_.Size()) {
551 for (const int64 hole : InitAndGetValues(hole_iterator_)) {
552 int pos = 0;
553 IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
554 if (boolvar != nullptr) {
555 boolvar->SetValue(0);
556 watchers_.RemoveAt(pos);
557 }
558 }
559 } else {
560 ScanWatchers();
561 }
562 }
563 CheckInhibit();
564 }
565 }
566
567 // Optimized case if the variable is bound.
568 void VariableBound() {
569 DCHECK(variable_->Bound());
570 const int64 value = variable_->Min();
571 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
572 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
573 w.second->SetValue(w.first == value);
574 }
575 watchers_.RemoveAll();
576 var_demon_->inhibit(solver());
577 }
578
579 // Scans all the watchers to check and assign them.
580 void ScanWatchers() {
581 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
582 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
583 if (!variable_->Contains(w.first)) {
584 IntVar* const boolvar = w.second;
585 boolvar->SetValue(0);
586 watchers_.RemoveAt(pos);
587 }
588 }
589 }
590
591 // If the set of active watchers is empty, we can inhibit the demon on the
592 // main variable.
593 void CheckInhibit() {
594 if (watchers_.Empty()) {
595 var_demon_->inhibit(solver());
596 }
597 }
598
599 void Accept(ModelVisitor* const visitor) const override {
600 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
601 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
602 variable_);
603 std::vector<int64> all_coefficients;
604 std::vector<IntVar*> all_bool_vars;
605 for (int position = watchers_.start(); position < watchers_.end();
606 ++position) {
607 const std::pair<int64, IntVar*>& w = watchers_.At(position);
608 all_coefficients.push_back(w.first);
609 all_bool_vars.push_back(w.second);
610 }
611 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
612 all_bool_vars);
613 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
614 all_coefficients);
615 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
616 }
617
618 std::string DebugString() const override {
619 return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
620 }
621
622 private:
623 DomainIntVar* const variable_;
624 IntVarIterator* const hole_iterator_;
625 RevSwitch posted_;
626 Demon* var_demon_;
627 RevIntPtrMap<IntVar> watchers_;
628 };
629
630 // Optimized case for small maps.
631 class DenseValueWatcher : public BaseValueWatcher {
632 public:
633 class WatchDemon : public Demon {
634 public:
635 WatchDemon(DenseValueWatcher* const watcher, int64 value, IntVar* var)
636 : value_watcher_(watcher), value_(value), var_(var) {}
637 ~WatchDemon() override {}
638
639 void Run(Solver* const solver) override {
640 value_watcher_->ProcessValueWatcher(value_, var_);
641 }
642
643 private:
644 DenseValueWatcher* const value_watcher_;
645 const int64 value_;
646 IntVar* const var_;
647 };
648
649 class VarDemon : public Demon {
650 public:
651 explicit VarDemon(DenseValueWatcher* const watcher)
652 : value_watcher_(watcher) {}
653
654 ~VarDemon() override {}
655
656 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
657
658 private:
659 DenseValueWatcher* const value_watcher_;
660 };
661
662 DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
663 : BaseValueWatcher(solver),
664 variable_(variable),
665 hole_iterator_(variable_->MakeHoleIterator(true)),
666 var_demon_(nullptr),
667 offset_(variable->Min()),
668 watchers_(variable->Max() - variable->Min() + 1, nullptr),
669 active_watchers_(0) {}
670
671 ~DenseValueWatcher() override {}
672
673 IntVar* GetOrMakeValueWatcher(int64 value) override {
674 const int64 var_max = offset_ + watchers_.size() - 1; // Bad cast.
675 if (value < offset_ || value > var_max) {
676 return solver()->MakeIntConst(0);
677 }
678 const int index = value - offset_;
679 IntVar* const watcher = watchers_[index];
680 if (watcher != nullptr) return watcher;
681 if (variable_->Contains(value)) {
682 if (variable_->Bound()) {
683 return solver()->MakeIntConst(1);
684 } else {
685 const std::string vname = variable_->HasName()
686 ? variable_->name()
687 : variable_->DebugString();
688 const std::string bname =
689 absl::StrFormat("Watch<%s == %d>", vname, value);
690 IntVar* const boolvar = solver()->MakeBoolVar(bname);
691 RevInsert(index, boolvar);
692 if (posted_.Switched()) {
693 boolvar->WhenBound(
694 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
695 var_demon_->desinhibit(solver());
696 }
697 return boolvar;
698 }
699 } else {
700 return variable_->solver()->MakeIntConst(0);
701 }
702 }
703
704 void SetValueWatcher(IntVar* const boolvar, int64 value) override {
705 const int index = value - offset_;
706 CHECK(watchers_[index] == nullptr);
707 if (!boolvar->Bound()) {
708 RevInsert(index, boolvar);
709 if (posted_.Switched() && !boolvar->Bound()) {
710 boolvar->WhenBound(
711 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
712 var_demon_->desinhibit(solver());
713 }
714 }
715 }
716
717 void Post() override {
718 var_demon_ = solver()->RevAlloc(new VarDemon(this));
719 variable_->WhenDomain(var_demon_);
720 for (int pos = 0; pos < watchers_.size(); ++pos) {
721 const int64 value = pos + offset_;
722 IntVar* const boolvar = watchers_[pos];
723 if (boolvar != nullptr && !boolvar->Bound() &&
724 variable_->Contains(value)) {
725 boolvar->WhenBound(
726 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
727 }
728 }
729 posted_.Switch(solver());
730 }
731
732 void InitialPropagate() override {
733 if (variable_->Bound()) {
734 VariableBound();
735 } else {
736 for (int pos = 0; pos < watchers_.size(); ++pos) {
737 IntVar* const boolvar = watchers_[pos];
738 if (boolvar == nullptr) continue;
739 const int64 value = pos + offset_;
740 if (!variable_->Contains(value)) {
741 boolvar->SetValue(0);
742 RevRemove(pos);
743 } else if (boolvar->Bound()) {
744 ProcessValueWatcher(value, boolvar);
745 RevRemove(pos);
746 }
747 }
748 if (active_watchers_.Value() == 0) {
749 var_demon_->inhibit(solver());
750 }
751 }
752 }
753
754 void ProcessValueWatcher(int64 value, IntVar* boolvar) {
755 if (boolvar->Min() == 0) {
756 variable_->RemoveValue(value);
757 } else {
758 variable_->SetValue(value);
759 }
760 }
761
762 void ProcessVar() {
763 if (variable_->Bound()) {
764 VariableBound();
765 } else {
766 // Brute force loop for small numbers of watchers.
767 ScanWatchers();
768 if (active_watchers_.Value() == 0) {
769 var_demon_->inhibit(solver());
770 }
771 }
772 }
773
774 // Optimized case if the variable is bound.
775 void VariableBound() {
776 DCHECK(variable_->Bound());
777 const int64 value = variable_->Min();
778 for (int pos = 0; pos < watchers_.size(); ++pos) {
779 IntVar* const boolvar = watchers_[pos];
780 if (boolvar != nullptr) {
781 boolvar->SetValue(pos + offset_ == value);
782 RevRemove(pos);
783 }
784 }
785 var_demon_->inhibit(solver());
786 }
787
788 // Scans all the watchers to check and assign them.
789 void ScanWatchers() {
790 const int64 old_min_index = variable_->OldMin() - offset_;
791 const int64 old_max_index = variable_->OldMax() - offset_;
792 const int64 min_index = variable_->Min() - offset_;
793 const int64 max_index = variable_->Max() - offset_;
794 for (int pos = old_min_index; pos < min_index; ++pos) {
795 IntVar* const boolvar = watchers_[pos];
796 if (boolvar != nullptr) {
797 boolvar->SetValue(0);
798 RevRemove(pos);
799 }
800 }
801 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
802 IntVar* const boolvar = watchers_[pos];
803 if (boolvar != nullptr) {
804 boolvar->SetValue(0);
805 RevRemove(pos);
806 }
807 }
808 BitSet* const bitset = variable_->bitset();
809 if (bitset != nullptr) {
810 if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
811 for (const int64 hole : InitAndGetValues(hole_iterator_)) {
812 IntVar* const boolvar = watchers_[hole - offset_];
813 if (boolvar != nullptr) {
814 boolvar->SetValue(0);
815 RevRemove(hole - offset_);
816 }
817 }
818 } else {
819 for (int pos = min_index + 1; pos < max_index; ++pos) {
820 IntVar* const boolvar = watchers_[pos];
821 if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
822 boolvar->SetValue(0);
823 RevRemove(pos);
824 }
825 }
826 }
827 }
828 }
829
830 void RevRemove(int pos) {
831 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
832 watchers_[pos] = nullptr;
833 active_watchers_.Decr(solver());
834 }
835
836 void RevInsert(int pos, IntVar* boolvar) {
837 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
838 watchers_[pos] = boolvar;
839 active_watchers_.Incr(solver());
840 }
841
842 void Accept(ModelVisitor* const visitor) const override {
843 visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
844 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
845 variable_);
846 std::vector<int64> all_coefficients;
847 std::vector<IntVar*> all_bool_vars;
848 for (int position = 0; position < watchers_.size(); ++position) {
849 if (watchers_[position] != nullptr) {
850 all_coefficients.push_back(position + offset_);
851 all_bool_vars.push_back(watchers_[position]);
852 }
853 }
854 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
855 all_bool_vars);
856 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
857 all_coefficients);
858 visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
859 }
860
861 std::string DebugString() const override {
862 return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
863 }
864
865 private:
866 DomainIntVar* const variable_;
867 IntVarIterator* const hole_iterator_;
868 RevSwitch posted_;
869 Demon* var_demon_;
870 const int64 offset_;
871 std::vector<IntVar*> watchers_;
872 NumericalRev<int> active_watchers_;
873 };
874
875 class BaseUpperBoundWatcher : public Constraint {
876 public:
877 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
878
879 ~BaseUpperBoundWatcher() override {}
880
881 virtual IntVar* GetOrMakeUpperBoundWatcher(int64 value) = 0;
882
883 virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) = 0;
884 };
885
886 // This class watches the bounds of the variable and updates the
887 // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
888 // accordingly.
889 class UpperBoundWatcher : public BaseUpperBoundWatcher {
890 public:
891 class WatchDemon : public Demon {
892 public:
893 WatchDemon(UpperBoundWatcher* const watcher, int64 index,
894 IntVar* const var)
895 : value_watcher_(watcher), index_(index), var_(var) {}
896 ~WatchDemon() override {}
897
898 void Run(Solver* const solver) override {
899 value_watcher_->ProcessUpperBoundWatcher(index_, var_);
900 }
901
902 private:
903 UpperBoundWatcher* const value_watcher_;
904 const int64 index_;
905 IntVar* const var_;
906 };
907
908 class VarDemon : public Demon {
909 public:
910 explicit VarDemon(UpperBoundWatcher* const watcher)
911 : value_watcher_(watcher) {}
912 ~VarDemon() override {}
913
914 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
915
916 private:
917 UpperBoundWatcher* const value_watcher_;
918 };
919
920 UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
921 : BaseUpperBoundWatcher(solver),
922 variable_(variable),
923 var_demon_(nullptr),
924 watchers_(solver, variable->Min(), variable->Max()),
925 start_(0),
926 end_(0),
927 sorted_(false) {}
928
929 ~UpperBoundWatcher() override {}
930
931 IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
932 IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
933 if (watcher != nullptr) {
934 return watcher;
935 }
936 if (variable_->Max() >= value) {
937 if (variable_->Min() >= value) {
938 return solver()->MakeIntConst(1);
939 } else {
940 const std::string vname = variable_->HasName()
941 ? variable_->name()
942 : variable_->DebugString();
943 const std::string bname =
944 absl::StrFormat("Watch<%s >= %d>", vname, value);
945 IntVar* const boolvar = solver()->MakeBoolVar(bname);
946 watchers_.UnsafeRevInsert(value, boolvar);
947 if (posted_.Switched()) {
948 boolvar->WhenBound(
949 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
950 var_demon_->desinhibit(solver());
951 sorted_ = false;
952 }
953 return boolvar;
954 }
955 } else {
956 return variable_->solver()->MakeIntConst(0);
957 }
958 }
959
960 void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
961 CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
962 watchers_.UnsafeRevInsert(value, boolvar);
963 if (posted_.Switched() && !boolvar->Bound()) {
964 boolvar->WhenBound(
965 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
966 var_demon_->desinhibit(solver());
967 sorted_ = false;
968 }
969 }
970
971 void Post() override {
972 const int kTooSmallToSort = 8;
973 var_demon_ = solver()->RevAlloc(new VarDemon(this));
974 variable_->WhenRange(var_demon_);
975
976 if (watchers_.Size() > kTooSmallToSort) {
977 watchers_.SortActive();
978 sorted_ = true;
979 start_.SetValue(solver(), watchers_.start());
980 end_.SetValue(solver(), watchers_.end() - 1);
981 }
982
983 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
984 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
985 IntVar* const boolvar = w.second;
986 const int64 value = w.first;
987 if (!boolvar->Bound() && value > variable_->Min() &&
988 value <= variable_->Max()) {
989 boolvar->WhenBound(
990 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
991 }
992 }
993 posted_.Switch(solver());
994 }
995
996 void InitialPropagate() override {
997 const int64 var_min = variable_->Min();
998 const int64 var_max = variable_->Max();
999 if (sorted_) {
1000 while (start_.Value() <= end_.Value()) {
1001 const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1002 if (w.first <= var_min) {
1003 w.second->SetValue(1);
1004 start_.Incr(solver());
1005 } else {
1006 break;
1007 }
1008 }
1009 while (end_.Value() >= start_.Value()) {
1010 const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1011 if (w.first > var_max) {
1012 w.second->SetValue(0);
1013 end_.Decr(solver());
1014 } else {
1015 break;
1016 }
1017 }
1018 for (int i = start_.Value(); i <= end_.Value(); ++i) {
1019 const std::pair<int64, IntVar*>& w = watchers_.At(i);
1020 if (w.second->Bound()) {
1021 ProcessUpperBoundWatcher(w.first, w.second);
1022 }
1023 }
1024 if (start_.Value() > end_.Value()) {
1025 var_demon_->inhibit(solver());
1026 }
1027 } else {
1028 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1029 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1030 const int64 value = w.first;
1031 IntVar* const boolvar = w.second;
1032
1033 if (value <= var_min) {
1034 boolvar->SetValue(1);
1035 watchers_.RemoveAt(pos);
1036 } else if (value > var_max) {
1037 boolvar->SetValue(0);
1038 watchers_.RemoveAt(pos);
1039 } else if (boolvar->Bound()) {
1040 ProcessUpperBoundWatcher(value, boolvar);
1041 watchers_.RemoveAt(pos);
1042 }
1043 }
1044 }
1045 }
1046
1047 void Accept(ModelVisitor* const visitor) const override {
1048 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1049 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1050 variable_);
1051 std::vector<int64> all_coefficients;
1052 std::vector<IntVar*> all_bool_vars;
1053 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1054 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1055 all_coefficients.push_back(w.first);
1056 all_bool_vars.push_back(w.second);
1057 }
1058 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1059 all_bool_vars);
1060 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1061 all_coefficients);
1062 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1063 }
1064
1065 std::string DebugString() const override {
1066 return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1067 }
1068
1069 private:
1070 void ProcessUpperBoundWatcher(int64 value, IntVar* const boolvar) {
1071 if (boolvar->Min() == 0) {
1072 variable_->SetMax(value - 1);
1073 } else {
1074 variable_->SetMin(value);
1075 }
1076 }
1077
1078 void ProcessVar() {
1079 const int64 var_min = variable_->Min();
1080 const int64 var_max = variable_->Max();
1081 if (sorted_) {
1082 while (start_.Value() <= end_.Value()) {
1083 const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1084 if (w.first <= var_min) {
1085 w.second->SetValue(1);
1086 start_.Incr(solver());
1087 } else {
1088 break;
1089 }
1090 }
1091 while (end_.Value() >= start_.Value()) {
1092 const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1093 if (w.first > var_max) {
1094 w.second->SetValue(0);
1095 end_.Decr(solver());
1096 } else {
1097 break;
1098 }
1099 }
1100 if (start_.Value() > end_.Value()) {
1101 var_demon_->inhibit(solver());
1102 }
1103 } else {
1104 for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1105 const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1106 const int64 value = w.first;
1107 IntVar* const boolvar = w.second;
1108
1109 if (value <= var_min) {
1110 boolvar->SetValue(1);
1111 watchers_.RemoveAt(pos);
1112 } else if (value > var_max) {
1113 boolvar->SetValue(0);
1114 watchers_.RemoveAt(pos);
1115 }
1116 }
1117 if (watchers_.Empty()) {
1118 var_demon_->inhibit(solver());
1119 }
1120 }
1121 }
1122
1123 DomainIntVar* const variable_;
1124 RevSwitch posted_;
1125 Demon* var_demon_;
1126 RevIntPtrMap<IntVar> watchers_;
1127 NumericalRev<int> start_;
1128 NumericalRev<int> end_;
1129 bool sorted_;
1130 };
1131
1132 // Optimized case for small maps.
1133 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1134 public:
1135 class WatchDemon : public Demon {
1136 public:
1137 WatchDemon(DenseUpperBoundWatcher* const watcher, int64 value,
1138 IntVar* var)
1139 : value_watcher_(watcher), value_(value), var_(var) {}
1140 ~WatchDemon() override {}
1141
1142 void Run(Solver* const solver) override {
1143 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1144 }
1145
1146 private:
1147 DenseUpperBoundWatcher* const value_watcher_;
1148 const int64 value_;
1149 IntVar* const var_;
1150 };
1151
1152 class VarDemon : public Demon {
1153 public:
1154 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1155 : value_watcher_(watcher) {}
1156
1157 ~VarDemon() override {}
1158
1159 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1160
1161 private:
1162 DenseUpperBoundWatcher* const value_watcher_;
1163 };
1164
1165 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1166 : BaseUpperBoundWatcher(solver),
1167 variable_(variable),
1168 var_demon_(nullptr),
1169 offset_(variable->Min()),
1170 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1171 active_watchers_(0) {}
1172
1173 ~DenseUpperBoundWatcher() override {}
1174
1175 IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
1176 if (variable_->Max() >= value) {
1177 if (variable_->Min() >= value) {
1178 return solver()->MakeIntConst(1);
1179 } else {
1180 const std::string vname = variable_->HasName()
1181 ? variable_->name()
1182 : variable_->DebugString();
1183 const std::string bname =
1184 absl::StrFormat("Watch<%s >= %d>", vname, value);
1185 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1186 RevInsert(value - offset_, boolvar);
1187 if (posted_.Switched()) {
1188 boolvar->WhenBound(
1189 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1190 var_demon_->desinhibit(solver());
1191 }
1192 return boolvar;
1193 }
1194 } else {
1195 return variable_->solver()->MakeIntConst(0);
1196 }
1197 }
1198
1199 void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
1200 const int index = value - offset_;
1201 CHECK(watchers_[index] == nullptr);
1202 if (!boolvar->Bound()) {
1203 RevInsert(index, boolvar);
1204 if (posted_.Switched() && !boolvar->Bound()) {
1205 boolvar->WhenBound(
1206 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1207 var_demon_->desinhibit(solver());
1208 }
1209 }
1210 }
1211
1212 void Post() override {
1213 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1214 variable_->WhenRange(var_demon_);
1215 for (int pos = 0; pos < watchers_.size(); ++pos) {
1216 const int64 value = pos + offset_;
1217 IntVar* const boolvar = watchers_[pos];
1218 if (boolvar != nullptr && !boolvar->Bound() &&
1219 value > variable_->Min() && value <= variable_->Max()) {
1220 boolvar->WhenBound(
1221 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1222 }
1223 }
1224 posted_.Switch(solver());
1225 }
1226
1227 void InitialPropagate() override {
1228 for (int pos = 0; pos < watchers_.size(); ++pos) {
1229 IntVar* const boolvar = watchers_[pos];
1230 if (boolvar == nullptr) continue;
1231 const int64 value = pos + offset_;
1232 if (value <= variable_->Min()) {
1233 boolvar->SetValue(1);
1234 RevRemove(pos);
1235 } else if (value > variable_->Max()) {
1236 boolvar->SetValue(0);
1237 RevRemove(pos);
1238 } else if (boolvar->Bound()) {
1239 ProcessUpperBoundWatcher(value, boolvar);
1240 RevRemove(pos);
1241 }
1242 }
1243 if (active_watchers_.Value() == 0) {
1244 var_demon_->inhibit(solver());
1245 }
1246 }
1247
1248 void ProcessUpperBoundWatcher(int64 value, IntVar* boolvar) {
1249 if (boolvar->Min() == 0) {
1250 variable_->SetMax(value - 1);
1251 } else {
1252 variable_->SetMin(value);
1253 }
1254 }
1255
1256 void ProcessVar() {
1257 const int64 old_min_index = variable_->OldMin() - offset_;
1258 const int64 old_max_index = variable_->OldMax() - offset_;
1259 const int64 min_index = variable_->Min() - offset_;
1260 const int64 max_index = variable_->Max() - offset_;
1261 for (int pos = old_min_index; pos <= min_index; ++pos) {
1262 IntVar* const boolvar = watchers_[pos];
1263 if (boolvar != nullptr) {
1264 boolvar->SetValue(1);
1265 RevRemove(pos);
1266 }
1267 }
1268
1269 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1270 IntVar* const boolvar = watchers_[pos];
1271 if (boolvar != nullptr) {
1272 boolvar->SetValue(0);
1273 RevRemove(pos);
1274 }
1275 }
1276 if (active_watchers_.Value() == 0) {
1277 var_demon_->inhibit(solver());
1278 }
1279 }
1280
1281 void RevRemove(int pos) {
1282 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1283 watchers_[pos] = nullptr;
1284 active_watchers_.Decr(solver());
1285 }
1286
1287 void RevInsert(int pos, IntVar* boolvar) {
1288 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1289 watchers_[pos] = boolvar;
1290 active_watchers_.Incr(solver());
1291 }
1292
1293 void Accept(ModelVisitor* const visitor) const override {
1294 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1295 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1296 variable_);
1297 std::vector<int64> all_coefficients;
1298 std::vector<IntVar*> all_bool_vars;
1299 for (int position = 0; position < watchers_.size(); ++position) {
1300 if (watchers_[position] != nullptr) {
1301 all_coefficients.push_back(position + offset_);
1302 all_bool_vars.push_back(watchers_[position]);
1303 }
1304 }
1305 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1306 all_bool_vars);
1307 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1308 all_coefficients);
1309 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1310 }
1311
1312 std::string DebugString() const override {
1313 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1314 variable_->DebugString());
1315 }
1316
1317 private:
1318 DomainIntVar* const variable_;
1319 RevSwitch posted_;
1320 Demon* var_demon_;
1321 const int64 offset_;
1322 std::vector<IntVar*> watchers_;
1323 NumericalRev<int> active_watchers_;
1324 };
1325
1326 // ----- Main Class -----
1327 DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
1328 const std::string& name);
1329 DomainIntVar(Solver* const s, const std::vector<int64>& sorted_values,
1330 const std::string& name);
1331 ~DomainIntVar() override;
1332
1333 int64 Min() const override { return min_.Value(); }
1334 void SetMin(int64 m) override;
1335 int64 Max() const override { return max_.Value(); }
1336 void SetMax(int64 m) override;
1337 void SetRange(int64 mi, int64 ma) override;
1338 void SetValue(int64 v) override;
1339 bool Bound() const override { return (min_.Value() == max_.Value()); }
1340 int64 Value() const override {
1341 CHECK_EQ(min_.Value(), max_.Value())
1342 << " variable " << DebugString() << " is not bound.";
1343 return min_.Value();
1344 }
1345 void RemoveValue(int64 v) override;
1346 void RemoveInterval(int64 l, int64 u) override;
1347 void CreateBits();
1348 void WhenBound(Demon* d) override {
1349 if (min_.Value() != max_.Value()) {
1350 if (d->priority() == Solver::DELAYED_PRIORITY) {
1351 delayed_bound_demons_.PushIfNotTop(solver(),
1352 solver()->RegisterDemon(d));
1353 } else {
1354 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1355 }
1356 }
1357 }
1358 void WhenRange(Demon* d) override {
1359 if (min_.Value() != max_.Value()) {
1360 if (d->priority() == Solver::DELAYED_PRIORITY) {
1361 delayed_range_demons_.PushIfNotTop(solver(),
1362 solver()->RegisterDemon(d));
1363 } else {
1364 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1365 }
1366 }
1367 }
1368 void WhenDomain(Demon* d) override {
1369 if (min_.Value() != max_.Value()) {
1370 if (d->priority() == Solver::DELAYED_PRIORITY) {
1371 delayed_domain_demons_.PushIfNotTop(solver(),
1372 solver()->RegisterDemon(d));
1373 } else {
1374 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1375 }
1376 }
1377 }
1378
1379 IntVar* IsEqual(int64 constant) override {
1380 Solver* const s = solver();
1381 if (constant == min_.Value() && value_watcher_ == nullptr) {
1382 return s->MakeIsLessOrEqualCstVar(this, constant);
1383 }
1384 if (constant == max_.Value() && value_watcher_ == nullptr) {
1385 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1386 }
1387 if (!Contains(constant)) {
1388 return s->MakeIntConst(int64{0});
1389 }
1390 if (Bound() && min_.Value() == constant) {
1391 return s->MakeIntConst(int64{1});
1392 }
1393 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1394 this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1395 if (cache != nullptr) {
1396 return cache->Var();
1397 } else {
1398 if (value_watcher_ == nullptr) {
1399 if (CapSub(Max(), Min()) <= 256) {
1400 solver()->SaveAndSetValue(
1401 reinterpret_cast<void**>(&value_watcher_),
1402 reinterpret_cast<void*>(
1403 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1404
1405 } else {
1406 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1407 reinterpret_cast<void*>(solver()->RevAlloc(
1408 new ValueWatcher(solver(), this))));
1409 }
1410 solver()->AddConstraint(value_watcher_);
1411 }
1412 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1413 s->Cache()->InsertExprConstantExpression(
1414 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1415 return boolvar;
1416 }
1417 }
1418
1419 Constraint* SetIsEqual(const std::vector<int64>& values,
1420 const std::vector<IntVar*>& vars) {
1421 if (value_watcher_ == nullptr) {
1422 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1423 reinterpret_cast<void*>(solver()->RevAlloc(
1424 new ValueWatcher(solver(), this))));
1425 for (int i = 0; i < vars.size(); ++i) {
1426 value_watcher_->SetValueWatcher(vars[i], values[i]);
1427 }
1428 }
1429 return value_watcher_;
1430 }
1431
1432 IntVar* IsDifferent(int64 constant) override {
1433 Solver* const s = solver();
1434 if (constant == min_.Value() && value_watcher_ == nullptr) {
1435 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1436 }
1437 if (constant == max_.Value() && value_watcher_ == nullptr) {
1438 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1439 }
1440 if (!Contains(constant)) {
1441 return s->MakeIntConst(int64{1});
1442 }
1443 if (Bound() && min_.Value() == constant) {
1444 return s->MakeIntConst(int64{0});
1445 }
1446 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1448 if (cache != nullptr) {
1449 return cache->Var();
1450 } else {
1451 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1452 s->Cache()->InsertExprConstantExpression(
1453 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1454 return boolvar;
1455 }
1456 }
1457
1458 IntVar* IsGreaterOrEqual(int64 constant) override {
1459 Solver* const s = solver();
1460 if (max_.Value() < constant) {
1461 return s->MakeIntConst(int64{0});
1462 }
1463 if (min_.Value() >= constant) {
1464 return s->MakeIntConst(int64{1});
1465 }
1466 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1468 if (cache != nullptr) {
1469 return cache->Var();
1470 } else {
1471 if (bound_watcher_ == nullptr) {
1472 if (CapSub(Max(), Min()) <= 256) {
1473 solver()->SaveAndSetValue(
1474 reinterpret_cast<void**>(&bound_watcher_),
1475 reinterpret_cast<void*>(solver()->RevAlloc(
1476 new DenseUpperBoundWatcher(solver(), this))));
1477 solver()->AddConstraint(bound_watcher_);
1478 } else {
1479 solver()->SaveAndSetValue(
1480 reinterpret_cast<void**>(&bound_watcher_),
1481 reinterpret_cast<void*>(
1482 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1483 solver()->AddConstraint(bound_watcher_);
1484 }
1485 }
1486 IntVar* const boolvar =
1487 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1488 s->Cache()->InsertExprConstantExpression(
1489 boolvar, this, constant,
1491 return boolvar;
1492 }
1493 }
1494
1495 Constraint* SetIsGreaterOrEqual(const std::vector<int64>& values,
1496 const std::vector<IntVar*>& vars) {
1497 if (bound_watcher_ == nullptr) {
1498 if (CapSub(Max(), Min()) <= 256) {
1499 solver()->SaveAndSetValue(
1500 reinterpret_cast<void**>(&bound_watcher_),
1501 reinterpret_cast<void*>(solver()->RevAlloc(
1502 new DenseUpperBoundWatcher(solver(), this))));
1503 solver()->AddConstraint(bound_watcher_);
1504 } else {
1505 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1506 reinterpret_cast<void*>(solver()->RevAlloc(
1507 new UpperBoundWatcher(solver(), this))));
1508 solver()->AddConstraint(bound_watcher_);
1509 }
1510 for (int i = 0; i < values.size(); ++i) {
1511 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1512 }
1513 }
1514 return bound_watcher_;
1515 }
1516
1517 IntVar* IsLessOrEqual(int64 constant) override {
1518 Solver* const s = solver();
1519 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1521 if (cache != nullptr) {
1522 return cache->Var();
1523 } else {
1524 IntVar* const boolvar =
1525 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1526 s->Cache()->InsertExprConstantExpression(
1527 boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1528 return boolvar;
1529 }
1530 }
1531
1532 void Process();
1533 void Push();
1534 void CleanInProcess();
1535 uint64 Size() const override {
1536 if (bits_ != nullptr) return bits_->Size();
1537 return (static_cast<uint64>(max_.Value()) -
1538 static_cast<uint64>(min_.Value()) + 1);
1539 }
1540 bool Contains(int64 v) const override {
1541 if (v < min_.Value() || v > max_.Value()) return false;
1542 return (bits_ == nullptr ? true : bits_->Contains(v));
1543 }
1544 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1545 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1546 int64 OldMin() const override { return std::min(old_min_, min_.Value()); }
1547 int64 OldMax() const override { return std::max(old_max_, max_.Value()); }
1548
1549 std::string DebugString() const override;
1550 BitSet* bitset() const { return bits_; }
1551 int VarType() const override { return DOMAIN_INT_VAR; }
1552 std::string BaseName() const override { return "IntegerVar"; }
1553
1554 friend class PlusCstDomainIntVar;
1555 friend class LinkExprAndDomainIntVar;
1556
1557 private:
1558 void CheckOldMin() {
1559 if (old_min_ > min_.Value()) {
1560 old_min_ = min_.Value();
1561 }
1562 }
1563 void CheckOldMax() {
1564 if (old_max_ < max_.Value()) {
1565 old_max_ = max_.Value();
1566 }
1567 }
1568 Rev<int64> min_;
1569 Rev<int64> max_;
1570 int64 old_min_;
1571 int64 old_max_;
1572 int64 new_min_;
1573 int64 new_max_;
1574 SimpleRevFIFO<Demon*> bound_demons_;
1575 SimpleRevFIFO<Demon*> range_demons_;
1576 SimpleRevFIFO<Demon*> domain_demons_;
1577 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1578 SimpleRevFIFO<Demon*> delayed_range_demons_;
1579 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1580 QueueHandler handler_;
1581 bool in_process_;
1582 BitSet* bits_;
1583 BaseValueWatcher* value_watcher_;
1584 BaseUpperBoundWatcher* bound_watcher_;
1585};
1586
1587// ----- BitSet -----
1588
1589// Return whether an integer interval [a..b] (inclusive) contains at most
1590// K values, i.e. b - a < K, in a way that's robust to overflows.
1591// For performance reasons, in opt mode it doesn't check that [a, b] is a
1592// valid interval, nor that K is nonnegative.
1593inline bool ClosedIntervalNoLargerThan(int64 a, int64 b, int64 K) {
1594 DCHECK_LE(a, b);
1595 DCHECK_GE(K, 0);
1596 if (a > 0) {
1597 return a > b - K;
1598 } else {
1599 return a + K > b;
1600 }
1601}
1602
1603class SimpleBitSet : public DomainIntVar::BitSet {
1604 public:
1605 SimpleBitSet(Solver* const s, int64 vmin, int64 vmax)
1606 : BitSet(s),
1607 bits_(nullptr),
1608 stamps_(nullptr),
1609 omin_(vmin),
1610 omax_(vmax),
1611 size_(vmax - vmin + 1),
1612 bsize_(BitLength64(size_.Value())) {
1613 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1614 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1615 bits_ = new uint64[bsize_];
1616 stamps_ = new uint64[bsize_];
1617 for (int i = 0; i < bsize_; ++i) {
1618 const int bs =
1619 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1620 bits_[i] = kAllBits64 >> bs;
1621 stamps_[i] = s->stamp() - 1;
1622 }
1623 }
1624
1625 SimpleBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1626 int64 vmin, int64 vmax)
1627 : BitSet(s),
1628 bits_(nullptr),
1629 stamps_(nullptr),
1630 omin_(vmin),
1631 omax_(vmax),
1632 size_(sorted_values.size()),
1633 bsize_(BitLength64(vmax - vmin + 1)) {
1634 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1635 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1636 bits_ = new uint64[bsize_];
1637 stamps_ = new uint64[bsize_];
1638 for (int i = 0; i < bsize_; ++i) {
1639 bits_[i] = uint64_t{0};
1640 stamps_[i] = s->stamp() - 1;
1641 }
1642 for (int i = 0; i < sorted_values.size(); ++i) {
1643 const int64 val = sorted_values[i];
1644 DCHECK(!bit(val));
1645 const int offset = BitOffset64(val - omin_);
1646 const int pos = BitPos64(val - omin_);
1647 bits_[offset] |= OneBit64(pos);
1648 }
1649 }
1650
1651 ~SimpleBitSet() override {
1652 delete[] bits_;
1653 delete[] stamps_;
1654 }
1655
1656 bool bit(int64 val) const { return IsBitSet64(bits_, val - omin_); }
1657
1658 int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1659 DCHECK_GE(nmin, cmin);
1660 DCHECK_LE(nmin, cmax);
1661 DCHECK_LE(cmin, cmax);
1662 DCHECK_GE(cmin, omin_);
1663 DCHECK_LE(cmax, omax_);
1664 const int64 new_min =
1665 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1666 omin_;
1667 const uint64 removed_bits =
1668 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1669 size_.Add(solver_, -removed_bits);
1670 return new_min;
1671 }
1672
1673 int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1674 DCHECK_GE(nmax, cmin);
1675 DCHECK_LE(nmax, cmax);
1676 DCHECK_LE(cmin, cmax);
1677 DCHECK_GE(cmin, omin_);
1678 DCHECK_LE(cmax, omax_);
1679 const int64 new_max =
1680 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1681 omin_;
1682 const uint64 removed_bits =
1683 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1684 size_.Add(solver_, -removed_bits);
1685 return new_max;
1686 }
1687
1688 bool SetValue(int64 val) override {
1689 DCHECK_GE(val, omin_);
1690 DCHECK_LE(val, omax_);
1691 if (bit(val)) {
1692 size_.SetValue(solver_, 1);
1693 return true;
1694 }
1695 return false;
1696 }
1697
1698 bool Contains(int64 val) const override {
1699 DCHECK_GE(val, omin_);
1700 DCHECK_LE(val, omax_);
1701 return bit(val);
1702 }
1703
1704 bool RemoveValue(int64 val) override {
1705 if (val < omin_ || val > omax_ || !bit(val)) {
1706 return false;
1707 }
1708 // Bitset.
1709 const int64 val_offset = val - omin_;
1710 const int offset = BitOffset64(val_offset);
1711 const uint64 current_stamp = solver_->stamp();
1712 if (stamps_[offset] < current_stamp) {
1713 stamps_[offset] = current_stamp;
1714 solver_->SaveValue(&bits_[offset]);
1715 }
1716 const int pos = BitPos64(val_offset);
1717 bits_[offset] &= ~OneBit64(pos);
1718 // Size.
1719 size_.Decr(solver_);
1720 // Holes.
1721 InitHoles();
1722 AddHole(val);
1723 return true;
1724 }
1725 uint64 Size() const override { return size_.Value(); }
1726
1727 std::string DebugString() const override {
1728 std::string out;
1729 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1730 for (int i = 0; i < bsize_; ++i) {
1731 absl::StrAppendFormat(&out, "%x", bits_[i]);
1732 }
1733 out += ")";
1734 return out;
1735 }
1736
1737 void DelayRemoveValue(int64 val) override { removed_.push_back(val); }
1738
1739 void ApplyRemovedValues(DomainIntVar* var) override {
1740 std::sort(removed_.begin(), removed_.end());
1741 for (std::vector<int64>::iterator it = removed_.begin();
1742 it != removed_.end(); ++it) {
1743 var->RemoveValue(*it);
1744 }
1745 }
1746
1747 void ClearRemovedValues() override { removed_.clear(); }
1748
1749 std::string pretty_DebugString(int64 min, int64 max) const override {
1750 std::string out;
1751 DCHECK(bit(min));
1752 DCHECK(bit(max));
1753 if (max != min) {
1754 int cumul = true;
1755 int64 start_cumul = min;
1756 for (int64 v = min + 1; v < max; ++v) {
1757 if (bit(v)) {
1758 if (!cumul) {
1759 cumul = true;
1760 start_cumul = v;
1761 }
1762 } else {
1763 if (cumul) {
1764 if (v == start_cumul + 1) {
1765 absl::StrAppendFormat(&out, "%d ", start_cumul);
1766 } else if (v == start_cumul + 2) {
1767 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1768 } else {
1769 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1770 }
1771 cumul = false;
1772 }
1773 }
1774 }
1775 if (cumul) {
1776 if (max == start_cumul + 1) {
1777 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1778 } else {
1779 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1780 }
1781 } else {
1782 absl::StrAppendFormat(&out, "%d", max);
1783 }
1784 } else {
1785 absl::StrAppendFormat(&out, "%d", min);
1786 }
1787 return out;
1788 }
1789
1790 DomainIntVar::BitSetIterator* MakeIterator() override {
1791 return new DomainIntVar::BitSetIterator(bits_, omin_);
1792 }
1793
1794 private:
1795 uint64* bits_;
1796 uint64* stamps_;
1797 const int64 omin_;
1798 const int64 omax_;
1799 NumericalRev<int64> size_;
1800 const int bsize_;
1801 std::vector<int64> removed_;
1802};
1803
1804// This is a special case where the bitset fits into one 64 bit integer.
1805// In that case, there are no offset to compute.
1806// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1807class SmallBitSet : public DomainIntVar::BitSet {
1808 public:
1809 SmallBitSet(Solver* const s, int64 vmin, int64 vmax)
1810 : BitSet(s),
1811 bits_(uint64_t{0}),
1812 stamp_(s->stamp() - 1),
1813 omin_(vmin),
1814 omax_(vmax),
1815 size_(vmax - vmin + 1) {
1816 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1817 bits_ = OneRange64(0, size_.Value() - 1);
1818 }
1819
1820 SmallBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1821 int64 vmin, int64 vmax)
1822 : BitSet(s),
1823 bits_(uint64_t{0}),
1824 stamp_(s->stamp() - 1),
1825 omin_(vmin),
1826 omax_(vmax),
1827 size_(sorted_values.size()) {
1828 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1829 // We know the array is sorted and does not contains duplicate values.
1830 for (int i = 0; i < sorted_values.size(); ++i) {
1831 const int64 val = sorted_values[i];
1832 DCHECK_GE(val, vmin);
1833 DCHECK_LE(val, vmax);
1834 DCHECK(!IsBitSet64(&bits_, val - omin_));
1835 bits_ |= OneBit64(val - omin_);
1836 }
1837 }
1838
1839 ~SmallBitSet() override {}
1840
1841 bool bit(int64 val) const {
1842 DCHECK_GE(val, omin_);
1843 DCHECK_LE(val, omax_);
1844 return (bits_ & OneBit64(val - omin_)) != 0;
1845 }
1846
1847 int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1848 DCHECK_GE(nmin, cmin);
1849 DCHECK_LE(nmin, cmax);
1850 DCHECK_LE(cmin, cmax);
1851 DCHECK_GE(cmin, omin_);
1852 DCHECK_LE(cmax, omax_);
1853 // We do not clean the bits between cmin and nmin.
1854 // But we use mask to look only at 'active' bits.
1855
1856 // Create the mask and compute new bits
1857 const uint64 new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1858 if (new_bits != uint64_t{0}) {
1859 // Compute new size and new min
1860 size_.SetValue(solver_, BitCount64(new_bits));
1861 if (bit(nmin)) { // Common case, the new min is inside the bitset
1862 return nmin;
1863 }
1864 return LeastSignificantBitPosition64(new_bits) + omin_;
1865 } else { // == 0 -> Fail()
1866 solver_->Fail();
1867 return kint64max;
1868 }
1869 }
1870
1871 int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1872 DCHECK_GE(nmax, cmin);
1873 DCHECK_LE(nmax, cmax);
1874 DCHECK_LE(cmin, cmax);
1875 DCHECK_GE(cmin, omin_);
1876 DCHECK_LE(cmax, omax_);
1877 // We do not clean the bits between nmax and cmax.
1878 // But we use mask to look only at 'active' bits.
1879
1880 // Create the mask and compute new_bits
1881 const uint64 new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1882 if (new_bits != uint64_t{0}) {
1883 // Compute new size and new min
1884 size_.SetValue(solver_, BitCount64(new_bits));
1885 if (bit(nmax)) { // Common case, the new max is inside the bitset
1886 return nmax;
1887 }
1888 return MostSignificantBitPosition64(new_bits) + omin_;
1889 } else { // == 0 -> Fail()
1890 solver_->Fail();
1891 return kint64min;
1892 }
1893 }
1894
1895 bool SetValue(int64 val) override {
1896 DCHECK_GE(val, omin_);
1897 DCHECK_LE(val, omax_);
1898 // We do not clean the bits. We will use masks to ignore the bits
1899 // that should have been cleaned.
1900 if (bit(val)) {
1901 size_.SetValue(solver_, 1);
1902 return true;
1903 }
1904 return false;
1905 }
1906
1907 bool Contains(int64 val) const override {
1908 DCHECK_GE(val, omin_);
1909 DCHECK_LE(val, omax_);
1910 return bit(val);
1911 }
1912
1913 bool RemoveValue(int64 val) override {
1914 DCHECK_GE(val, omin_);
1915 DCHECK_LE(val, omax_);
1916 if (bit(val)) {
1917 // Bitset.
1918 const uint64 current_stamp = solver_->stamp();
1919 if (stamp_ < current_stamp) {
1920 stamp_ = current_stamp;
1921 solver_->SaveValue(&bits_);
1922 }
1923 bits_ &= ~OneBit64(val - omin_);
1924 DCHECK(!bit(val));
1925 // Size.
1926 size_.Decr(solver_);
1927 // Holes.
1928 InitHoles();
1929 AddHole(val);
1930 return true;
1931 } else {
1932 return false;
1933 }
1934 }
1935
1936 uint64 Size() const override { return size_.Value(); }
1937
1938 std::string DebugString() const override {
1939 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1940 }
1941
1942 void DelayRemoveValue(int64 val) override {
1943 DCHECK_GE(val, omin_);
1944 DCHECK_LE(val, omax_);
1945 removed_.push_back(val);
1946 }
1947
1948 void ApplyRemovedValues(DomainIntVar* var) override {
1949 std::sort(removed_.begin(), removed_.end());
1950 for (std::vector<int64>::iterator it = removed_.begin();
1951 it != removed_.end(); ++it) {
1952 var->RemoveValue(*it);
1953 }
1954 }
1955
1956 void ClearRemovedValues() override { removed_.clear(); }
1957
1958 std::string pretty_DebugString(int64 min, int64 max) const override {
1959 std::string out;
1960 DCHECK(bit(min));
1961 DCHECK(bit(max));
1962 if (max != min) {
1963 int cumul = true;
1964 int64 start_cumul = min;
1965 for (int64 v = min + 1; v < max; ++v) {
1966 if (bit(v)) {
1967 if (!cumul) {
1968 cumul = true;
1969 start_cumul = v;
1970 }
1971 } else {
1972 if (cumul) {
1973 if (v == start_cumul + 1) {
1974 absl::StrAppendFormat(&out, "%d ", start_cumul);
1975 } else if (v == start_cumul + 2) {
1976 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1977 } else {
1978 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1979 }
1980 cumul = false;
1981 }
1982 }
1983 }
1984 if (cumul) {
1985 if (max == start_cumul + 1) {
1986 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1987 } else {
1988 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1989 }
1990 } else {
1991 absl::StrAppendFormat(&out, "%d", max);
1992 }
1993 } else {
1994 absl::StrAppendFormat(&out, "%d", min);
1995 }
1996 return out;
1997 }
1998
1999 DomainIntVar::BitSetIterator* MakeIterator() override {
2000 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2001 }
2002
2003 private:
2004 uint64 bits_;
2005 uint64 stamp_;
2006 const int64 omin_;
2007 const int64 omax_;
2008 NumericalRev<int64> size_;
2009 std::vector<int64> removed_;
2010};
2011
2012class EmptyIterator : public IntVarIterator {
2013 public:
2014 ~EmptyIterator() override {}
2015 void Init() override {}
2016 bool Ok() const override { return false; }
2017 int64 Value() const override {
2018 LOG(FATAL) << "Should not be called";
2019 return 0LL;
2020 }
2021 void Next() override {}
2022};
2023
2024class RangeIterator : public IntVarIterator {
2025 public:
2026 explicit RangeIterator(const IntVar* const var)
2027 : var_(var), min_(kint64max), max_(kint64min), current_(-1) {}
2028
2029 ~RangeIterator() override {}
2030
2031 void Init() override {
2032 min_ = var_->Min();
2033 max_ = var_->Max();
2034 current_ = min_;
2035 }
2036
2037 bool Ok() const override { return current_ <= max_; }
2038
2039 int64 Value() const override { return current_; }
2040
2041 void Next() override { current_++; }
2042
2043 private:
2044 const IntVar* const var_;
2045 int64 min_;
2046 int64 max_;
2048};
2049
2050class DomainIntVarHoleIterator : public IntVarIterator {
2051 public:
2052 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2053 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2054
2055 ~DomainIntVarHoleIterator() override {}
2056
2057 void Init() override {
2058 bits_ = var_->bitset();
2059 if (bits_ != nullptr) {
2060 bits_->InitHoles();
2061 values_ = bits_->Holes().data();
2062 size_ = bits_->Holes().size();
2063 } else {
2064 values_ = nullptr;
2065 size_ = 0;
2066 }
2067 index_ = 0;
2068 }
2069
2070 bool Ok() const override { return index_ < size_; }
2071
2072 int64 Value() const override {
2073 DCHECK(bits_ != nullptr);
2074 DCHECK(index_ < size_);
2075 return values_[index_];
2076 }
2077
2078 void Next() override { index_++; }
2079
2080 private:
2081 const DomainIntVar* const var_;
2082 DomainIntVar::BitSet* bits_;
2083 const int64* values_;
2084 int size_;
2085 int index_;
2086};
2087
2088class DomainIntVarDomainIterator : public IntVarIterator {
2089 public:
2090 explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2091 bool reversible)
2092 : var_(v),
2093 bitset_iterator_(nullptr),
2094 min_(kint64max),
2095 max_(kint64min),
2096 current_(-1),
2097 reversible_(reversible) {}
2098
2099 ~DomainIntVarDomainIterator() override {
2100 if (!reversible_ && bitset_iterator_) {
2101 delete bitset_iterator_;
2102 }
2103 }
2104
2105 void Init() override {
2106 if (var_->bitset() != nullptr && !var_->Bound()) {
2107 if (reversible_) {
2108 if (!bitset_iterator_) {
2109 Solver* const solver = var_->solver();
2110 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2111 bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2112 }
2113 } else {
2114 if (bitset_iterator_) {
2115 delete bitset_iterator_;
2116 }
2117 bitset_iterator_ = var_->bitset()->MakeIterator();
2118 }
2119 bitset_iterator_->Init(var_->Min(), var_->Max());
2120 } else {
2121 if (bitset_iterator_) {
2122 if (reversible_) {
2123 Solver* const solver = var_->solver();
2124 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2125 } else {
2126 delete bitset_iterator_;
2127 }
2128 bitset_iterator_ = nullptr;
2129 }
2130 min_ = var_->Min();
2131 max_ = var_->Max();
2132 current_ = min_;
2133 }
2134 }
2135
2136 bool Ok() const override {
2137 return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2138 }
2139
2140 int64 Value() const override {
2141 return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2142 }
2143
2144 void Next() override {
2145 if (bitset_iterator_) {
2146 bitset_iterator_->Next();
2147 } else {
2148 current_++;
2149 }
2150 }
2151
2152 private:
2153 const DomainIntVar* const var_;
2154 DomainIntVar::BitSetIterator* bitset_iterator_;
2155 int64 min_;
2156 int64 max_;
2158 const bool reversible_;
2159};
2160
2161class UnaryIterator : public IntVarIterator {
2162 public:
2163 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2164 : iterator_(hole ? v->MakeHoleIterator(reversible)
2165 : v->MakeDomainIterator(reversible)),
2166 reversible_(reversible) {}
2167
2168 ~UnaryIterator() override {
2169 if (!reversible_) {
2170 delete iterator_;
2171 }
2172 }
2173
2174 void Init() override { iterator_->Init(); }
2175
2176 bool Ok() const override { return iterator_->Ok(); }
2177
2178 void Next() override { iterator_->Next(); }
2179
2180 protected:
2181 IntVarIterator* const iterator_;
2182 const bool reversible_;
2183};
2184
// Creates a variable with the contiguous domain [vmin, vmax].
// The current bounds and the old/new bound bookkeeping all start at the full
// range. No bitset is allocated (bits_ is null); one is created lazily by
// CreateBits() when an interior value is first removed.
DomainIntVar::DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
                           const std::string& name)
    : IntVar(s, name),
      min_(vmin),
      max_(vmax),
      old_min_(vmin),
      old_max_(vmax),
      new_min_(vmin),
      new_max_(vmax),
      handler_(this),
      in_process_(false),
      bits_(nullptr),
      value_watcher_(nullptr),
      bound_watcher_(nullptr) {}
2199
// Creates a variable whose domain is exactly sorted_values. The vector must
// be non-empty (CHECKed), sorted and free of duplicates.
// A bitset is allocated only when the values are not contiguous: a
// SmallBitSet when the span vmax - vmin + 1 is at most 64, a SimpleBitSet
// otherwise.
DomainIntVar::DomainIntVar(Solver* const s,
                           const std::vector<int64>& sorted_values,
                           const std::string& name)
    : IntVar(s, name),
      min_(kint64max),
      max_(kint64min),
      old_min_(kint64max),
      old_max_(kint64min),
      new_min_(kint64max),
      new_max_(kint64min),
      handler_(this),
      in_process_(false),
      bits_(nullptr),
      value_watcher_(nullptr),
      bound_watcher_(nullptr) {
  CHECK_GE(sorted_values.size(), 1);
  // We know that the vector is sorted and does not have duplicate values.
  const int64 vmin = sorted_values.front();
  const int64 vmax = sorted_values.back();
  // Without duplicates, the values are contiguous iff their count equals
  // the width of [vmin, vmax].
  const bool contiguous = vmax - vmin + 1 == sorted_values.size();

  // The real bounds are set here, after the member initializers above.
  min_.SetValue(solver(), vmin);
  old_min_ = vmin;
  new_min_ = vmin;
  max_.SetValue(solver(), vmax);
  old_max_ = vmax;
  new_max_ = vmax;

  if (!contiguous) {
    if (vmax - vmin + 1 < 65) {
      bits_ = solver()->RevAlloc(
          new SmallBitSet(solver(), sorted_values, vmin, vmax));
    } else {
      bits_ = solver()->RevAlloc(
          new SimpleBitSet(solver(), sorted_values, vmin, vmax));
    }
  }
}
2238
2239DomainIntVar::~DomainIntVar() {}
2240
2241void DomainIntVar::SetMin(int64 m) {
2242 if (m <= min_.Value()) return;
2243 if (m > max_.Value()) solver()->Fail();
2244 if (in_process_) {
2245 if (m > new_min_) {
2246 new_min_ = m;
2247 if (new_min_ > new_max_) {
2248 solver()->Fail();
2249 }
2250 }
2251 } else {
2252 CheckOldMin();
2253 const int64 new_min =
2254 (bits_ == nullptr
2255 ? m
2256 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2257 min_.SetValue(solver(), new_min);
2258 if (min_.Value() > max_.Value()) {
2259 solver()->Fail();
2260 }
2261 Push();
2262 }
2263}
2264
2265void DomainIntVar::SetMax(int64 m) {
2266 if (m >= max_.Value()) return;
2267 if (m < min_.Value()) solver()->Fail();
2268 if (in_process_) {
2269 if (m < new_max_) {
2270 new_max_ = m;
2271 if (new_max_ < new_min_) {
2272 solver()->Fail();
2273 }
2274 }
2275 } else {
2276 CheckOldMax();
2277 const int64 new_max =
2278 (bits_ == nullptr
2279 ? m
2280 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2281 max_.SetValue(solver(), new_max);
2282 if (min_.Value() > max_.Value()) {
2283 solver()->Fail();
2284 }
2285 Push();
2286 }
2287}
2288
// Restricts the domain to [mi, ma]; fails if the result is empty.
// During processing, only the pending bounds new_min_/new_max_ are tightened.
void DomainIntVar::SetRange(int64 mi, int64 ma) {
  if (mi == ma) {
    // A singleton range is an assignment.
    SetValue(mi);
  } else {
    if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
    // Nothing to do if [mi, ma] already contains the whole domain.
    if (mi <= min_.Value() && ma >= max_.Value()) return;
    if (in_process_) {
      if (ma < new_max_) {
        new_max_ = ma;
      }
      if (mi > new_min_) {
        new_min_ = mi;
      }
      if (new_min_ > new_max_) {
        solver()->Fail();
      }
    } else {
      // Tighten the minimum first. With a bitset, ComputeNewMin presumably
      // skips values that were already removed.
      if (mi > min_.Value()) {
        CheckOldMin();
        const int64 new_min =
            (bits_ == nullptr
                 ? mi
                 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
        min_.SetValue(solver(), new_min);
      }
      // The adjusted minimum may already exceed the requested maximum.
      if (min_.Value() > ma) {
        solver()->Fail();
      }
      if (ma < max_.Value()) {
        CheckOldMax();
        const int64 new_max =
            (bits_ == nullptr
                 ? ma
                 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
        max_.SetValue(solver(), new_max);
      }
      if (min_.Value() > max_.Value()) {
        solver()->Fail();
      }
      Push();
    }
  }
}
2332
2333void DomainIntVar::SetValue(int64 v) {
2334 if (v != min_.Value() || v != max_.Value()) {
2335 if (v < min_.Value() || v > max_.Value()) {
2336 solver()->Fail();
2337 }
2338 if (in_process_) {
2339 if (v > new_max_ || v < new_min_) {
2340 solver()->Fail();
2341 }
2342 new_min_ = v;
2343 new_max_ = v;
2344 } else {
2345 if (bits_ && !bits_->SetValue(v)) {
2346 solver()->Fail();
2347 }
2348 CheckOldMin();
2349 CheckOldMax();
2350 min_.SetValue(solver(), v);
2351 max_.SetValue(solver(), v);
2352 Push();
2353 }
2354 }
2355}
2356
2357void DomainIntVar::RemoveValue(int64 v) {
2358 if (v < min_.Value() || v > max_.Value()) return;
2359 if (v == min_.Value()) {
2360 SetMin(v + 1);
2361 } else if (v == max_.Value()) {
2362 SetMax(v - 1);
2363 } else {
2364 if (bits_ == nullptr) {
2365 CreateBits();
2366 }
2367 if (in_process_) {
2368 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2369 bits_->DelayRemoveValue(v);
2370 }
2371 } else {
2372 if (bits_->RemoveValue(v)) {
2373 Push();
2374 }
2375 }
2376 }
2377}
2378
2379void DomainIntVar::RemoveInterval(int64 l, int64 u) {
2380 if (l <= min_.Value()) {
2381 SetMin(u + 1);
2382 } else if (u >= max_.Value()) {
2383 SetMax(l - 1);
2384 } else {
2385 for (int64 v = l; v <= u; ++v) {
2386 RemoveValue(v);
2387 }
2388 }
2389}
2390
2391void DomainIntVar::CreateBits() {
2392 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2393 if (max_.Value() - min_.Value() < 64) {
2394 bits_ = solver()->RevAlloc(
2395 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2396 } else {
2397 bits_ = solver()->RevAlloc(
2398 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2399 }
2400}
2401
2402void DomainIntVar::CleanInProcess() {
2403 in_process_ = false;
2404 if (bits_ != nullptr) {
2405 bits_->ClearHoles();
2406 }
2407}
2408
2409void DomainIntVar::Push() {
2410 const bool in_process = in_process_;
2411 EnqueueVar(&handler_);
2412 CHECK_EQ(in_process, in_process_);
2413}
2414
// Propagates the pending modifications of this variable. It runs the
// immediate demons and enqueues the delayed ones, then applies the bound
// tightenings and value removals requested while those demons were running.
void DomainIntVar::Process() {
  in_process_ = true;
  if (bits_ != nullptr) {
    bits_->ClearRemovedValues();
  }
  // Registered so the solver can clean this variable if a demon fails
  // (presumably by calling CleanInProcess(); TODO: confirm in the solver).
  set_variable_to_clean_on_fail(this);
  // While in_process_ is set, SetMin/SetMax/SetRange/SetValue only tighten
  // these pending bounds, not the real ones.
  new_min_ = min_.Value();
  new_max_ = max_.Value();
  const bool is_bound = min_.Value() == max_.Value();
  const bool range_changed =
      min_.Value() != OldMin() || max_.Value() != OldMax();
  // Process immediate demons.
  if (is_bound) {
    ExecuteAll(bound_demons_);
  }
  if (range_changed) {
    ExecuteAll(range_demons_);
  }
  ExecuteAll(domain_demons_);

  // Process delayed demons.
  if (is_bound) {
    EnqueueAll(delayed_bound_demons_);
  }
  if (range_changed) {
    EnqueueAll(delayed_range_demons_);
  }
  EnqueueAll(delayed_domain_demons_);

  // Everything went well if we arrive here. Let's clean the variable.
  set_variable_to_clean_on_fail(nullptr);
  CleanInProcess();
  // Current bounds become the old bounds (presumably what OldMin()/OldMax()
  // report for the next change detection).
  old_min_ = min_.Value();
  old_max_ = max_.Value();
  // Apply the deferred tightenings. in_process_ is now false, so these take
  // the regular path and push the variable again if they change it.
  if (min_.Value() < new_min_) {
    SetMin(new_min_);
  }
  if (max_.Value() > new_max_) {
    SetMax(new_max_);
  }
  if (bits_ != nullptr) {
    bits_->ApplyRemovedValues(this);
  }
}
2460
// Returns `alloc` (a freshly new'ed object) as is, or handed to
// solver()->RevAlloc() when `rev` is true. Must be expanded where solver()
// is in scope. The arguments and the whole expansion are parenthesized, and
// the expansion has no trailing ';', so the macro behaves as a single
// expression.
#define COND_REV_ALLOC(rev, alloc) \
  ((rev) ? solver()->RevAlloc(alloc) : (alloc))
2462
2463IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2464 return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2465}
2466
2467IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2468 return COND_REV_ALLOC(reversible,
2469 new DomainIntVarDomainIterator(this, reversible));
2470}
2471
2472std::string DomainIntVar::DebugString() const {
2473 std::string out;
2474 const std::string& var_name = name();
2475 if (!var_name.empty()) {
2476 out = var_name + "(";
2477 } else {
2478 out = "DomainIntVar(";
2479 }
2480 if (min_.Value() == max_.Value()) {
2481 absl::StrAppendFormat(&out, "%d", min_.Value());
2482 } else if (bits_ != nullptr) {
2483 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2484 } else {
2485 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2486 }
2487 out += ")";
2488 return out;
2489}
2490
2491// ----- Real Boolean Var -----
2492
2493class ConcreteBooleanVar : public BooleanVar {
2494 public:
2495 // Utility classes
2496 class Handler : public Demon {
2497 public:
2498 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2499 ~Handler() override {}
2500 void Run(Solver* const s) override {
2501 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2502 var_->Process();
2503 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2504 }
2505 Solver::DemonPriority priority() const override {
2506 return Solver::VAR_PRIORITY;
2507 }
2508 std::string DebugString() const override {
2509 return absl::StrFormat("Handler(%s)", var_->DebugString());
2510 }
2511
2512 private:
2513 ConcreteBooleanVar* const var_;
2514 };
2515
2516 ConcreteBooleanVar(Solver* const s, const std::string& name)
2517 : BooleanVar(s, name), handler_(this) {}
2518
2519 ~ConcreteBooleanVar() override {}
2520
2521 void SetValue(int64 v) override {
2522 if (value_ == kUnboundBooleanVarValue) {
2523 if ((v & 0xfffffffffffffffe) == 0) {
2524 InternalSaveBooleanVarValue(solver(), this);
2525 value_ = static_cast<int>(v);
2526 EnqueueVar(&handler_);
2527 return;
2528 }
2529 } else if (v == value_) {
2530 return;
2531 }
2532 solver()->Fail();
2533 }
2534
2535 void Process() {
2536 DCHECK_NE(value_, kUnboundBooleanVarValue);
2537 ExecuteAll(bound_demons_);
2538 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2539 ++it) {
2540 EnqueueDelayedDemon(*it);
2541 }
2542 }
2543
2544 int64 OldMin() const override { return 0LL; }
2545 int64 OldMax() const override { return 1LL; }
2546 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2547
2548 private:
2549 Handler handler_;
2550};
2551
2552// ----- IntConst -----
2553
2554class IntConst : public IntVar {
2555 public:
2556 IntConst(Solver* const s, int64 value, const std::string& name = "")
2557 : IntVar(s, name), value_(value) {}
2558 ~IntConst() override {}
2559
2560 int64 Min() const override { return value_; }
2561 void SetMin(int64 m) override {
2562 if (m > value_) {
2563 solver()->Fail();
2564 }
2565 }
2566 int64 Max() const override { return value_; }
2567 void SetMax(int64 m) override {
2568 if (m < value_) {
2569 solver()->Fail();
2570 }
2571 }
2572 void SetRange(int64 l, int64 u) override {
2573 if (l > value_ || u < value_) {
2574 solver()->Fail();
2575 }
2576 }
2577 void SetValue(int64 v) override {
2578 if (v != value_) {
2579 solver()->Fail();
2580 }
2581 }
2582 bool Bound() const override { return true; }
2583 int64 Value() const override { return value_; }
2584 void RemoveValue(int64 v) override {
2585 if (v == value_) {
2586 solver()->Fail();
2587 }
2588 }
2589 void RemoveInterval(int64 l, int64 u) override {
2590 if (l <= value_ && value_ <= u) {
2591 solver()->Fail();
2592 }
2593 }
2594 void WhenBound(Demon* d) override {}
2595 void WhenRange(Demon* d) override {}
2596 void WhenDomain(Demon* d) override {}
2597 uint64 Size() const override { return 1; }
2598 bool Contains(int64 v) const override { return (v == value_); }
2599 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2600 return COND_REV_ALLOC(reversible, new EmptyIterator());
2601 }
2602 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2603 return COND_REV_ALLOC(reversible, new RangeIterator(this));
2604 }
2605 int64 OldMin() const override { return value_; }
2606 int64 OldMax() const override { return value_; }
2607 std::string DebugString() const override {
2608 std::string out;
2609 if (solver()->HasName(this)) {
2610 const std::string& var_name = name();
2611 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2612 } else {
2613 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2614 }
2615 return out;
2616 }
2617
2618 int VarType() const override { return CONST_VAR; }
2619
2620 IntVar* IsEqual(int64 constant) override {
2621 if (constant == value_) {
2622 return solver()->MakeIntConst(1);
2623 } else {
2624 return solver()->MakeIntConst(0);
2625 }
2626 }
2627
2628 IntVar* IsDifferent(int64 constant) override {
2629 if (constant == value_) {
2630 return solver()->MakeIntConst(0);
2631 } else {
2632 return solver()->MakeIntConst(1);
2633 }
2634 }
2635
2636 IntVar* IsGreaterOrEqual(int64 constant) override {
2637 return solver()->MakeIntConst(value_ >= constant);
2638 }
2639
2640 IntVar* IsLessOrEqual(int64 constant) override {
2641 return solver()->MakeIntConst(value_ <= constant);
2642 }
2643
2644 std::string name() const override {
2645 if (solver()->HasName(this)) {
2647 } else {
2648 return absl::StrCat(value_);
2649 }
2650 }
2651
2652 private:
2653 int64 value_;
2654};
2655
2656// ----- x + c variable, optimized case -----
2657
2658class PlusCstVar : public IntVar {
2659 public:
2660 PlusCstVar(Solver* const s, IntVar* v, int64 c)
2661 : IntVar(s), var_(v), cst_(c) {}
2662
2663 ~PlusCstVar() override {}
2664
2665 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2666
2667 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2668
2669 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2670
2671 int64 OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2672
2673 int64 OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2674
2675 std::string DebugString() const override {
2676 if (HasName()) {
2677 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2678 } else {
2679 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2680 }
2681 }
2682
2683 int VarType() const override { return VAR_ADD_CST; }
2684
2685 void Accept(ModelVisitor* const visitor) const override {
2686 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2687 var_);
2688 }
2689
2690 IntVar* IsEqual(int64 constant) override {
2691 return var_->IsEqual(constant - cst_);
2692 }
2693
2694 IntVar* IsDifferent(int64 constant) override {
2695 return var_->IsDifferent(constant - cst_);
2696 }
2697
2698 IntVar* IsGreaterOrEqual(int64 constant) override {
2699 return var_->IsGreaterOrEqual(constant - cst_);
2700 }
2701
2702 IntVar* IsLessOrEqual(int64 constant) override {
2703 return var_->IsLessOrEqual(constant - cst_);
2704 }
2705
2706 IntVar* SubVar() const { return var_; }
2707
2708 int64 Constant() const { return cst_; }
2709
2710 protected:
2711 IntVar* const var_;
2713};
2714
2715class PlusCstIntVar : public PlusCstVar {
2716 public:
2717 class PlusCstIntVarIterator : public UnaryIterator {
2718 public:
2719 PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2720 : UnaryIterator(v, hole, rev), cst_(c) {}
2721
2722 ~PlusCstIntVarIterator() override {}
2723
2724 int64 Value() const override { return iterator_->Value() + cst_; }
2725
2726 private:
2727 const int64 cst_;
2728 };
2729
2730 PlusCstIntVar(Solver* const s, IntVar* v, int64 c) : PlusCstVar(s, v, c) {}
2731
2732 ~PlusCstIntVar() override {}
2733
2734 int64 Min() const override { return var_->Min() + cst_; }
2735
2736 void SetMin(int64 m) override { var_->SetMin(CapSub(m, cst_)); }
2737
2738 int64 Max() const override { return var_->Max() + cst_; }
2739
2740 void SetMax(int64 m) override { var_->SetMax(CapSub(m, cst_)); }
2741
2742 void SetRange(int64 l, int64 u) override {
2743 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2744 }
2745
2746 void SetValue(int64 v) override { var_->SetValue(v - cst_); }
2747
2748 int64 Value() const override { return var_->Value() + cst_; }
2749
2750 bool Bound() const override { return var_->Bound(); }
2751
2752 void RemoveValue(int64 v) override { var_->RemoveValue(v - cst_); }
2753
2754 void RemoveInterval(int64 l, int64 u) override {
2755 var_->RemoveInterval(l - cst_, u - cst_);
2756 }
2757
2758 uint64 Size() const override { return var_->Size(); }
2759
2760 bool Contains(int64 v) const override { return var_->Contains(v - cst_); }
2761
2762 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2763 return COND_REV_ALLOC(
2764 reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2765 }
2766 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2767 return COND_REV_ALLOC(
2768 reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2769 }
2770};
2771
2772class PlusCstDomainIntVar : public PlusCstVar {
2773 public:
2774 class PlusCstDomainIntVarIterator : public UnaryIterator {
2775 public:
2776 PlusCstDomainIntVarIterator(const IntVar* const v, int64 c, bool hole,
2777 bool reversible)
2778 : UnaryIterator(v, hole, reversible), cst_(c) {}
2779
2780 ~PlusCstDomainIntVarIterator() override {}
2781
2782 int64 Value() const override { return iterator_->Value() + cst_; }
2783
2784 private:
2785 const int64 cst_;
2786 };
2787
2788 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c)
2789 : PlusCstVar(s, v, c) {}
2790
2791 ~PlusCstDomainIntVar() override {}
2792
2793 int64 Min() const override;
2794 void SetMin(int64 m) override;
2795 int64 Max() const override;
2796 void SetMax(int64 m) override;
2797 void SetRange(int64 l, int64 u) override;
2798 void SetValue(int64 v) override;
2799 bool Bound() const override;
2800 int64 Value() const override;
2801 void RemoveValue(int64 v) override;
2802 void RemoveInterval(int64 l, int64 u) override;
2803 uint64 Size() const override;
2804 bool Contains(int64 v) const override;
2805
2806 DomainIntVar* domain_int_var() const {
2807 return reinterpret_cast<DomainIntVar*>(var_);
2808 }
2809
2810 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2811 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2812 var_, cst_, true, reversible));
2813 }
2814 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2815 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2816 var_, cst_, false, reversible));
2817 }
2818};
2819
2820int64 PlusCstDomainIntVar::Min() const {
2821 return domain_int_var()->min_.Value() + cst_;
2822}
2823
2824void PlusCstDomainIntVar::SetMin(int64 m) {
2825 domain_int_var()->DomainIntVar::SetMin(m - cst_);
2826}
2827
2828int64 PlusCstDomainIntVar::Max() const {
2829 return domain_int_var()->max_.Value() + cst_;
2830}
2831
2832void PlusCstDomainIntVar::SetMax(int64 m) {
2833 domain_int_var()->DomainIntVar::SetMax(m - cst_);
2834}
2835
2836void PlusCstDomainIntVar::SetRange(int64 l, int64 u) {
2837 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2838}
2839
2840void PlusCstDomainIntVar::SetValue(int64 v) {
2841 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2842}
2843
2844bool PlusCstDomainIntVar::Bound() const {
2845 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2846}
2847
2849 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2850 << " variable is not bound";
2851 return domain_int_var()->min_.Value() + cst_;
2852}
2853
2854void PlusCstDomainIntVar::RemoveValue(int64 v) {
2855 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2856}
2857
2858void PlusCstDomainIntVar::RemoveInterval(int64 l, int64 u) {
2859 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2860}
2861
2862uint64 PlusCstDomainIntVar::Size() const {
2863 return domain_int_var()->DomainIntVar::Size();
2864}
2865
2866bool PlusCstDomainIntVar::Contains(int64 v) const {
2867 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2868}
2869
2870// c - x variable, optimized case
2871
2872class SubCstIntVar : public IntVar {
2873 public:
2874 class SubCstIntVarIterator : public UnaryIterator {
2875 public:
2876 SubCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2877 : UnaryIterator(v, hole, rev), cst_(c) {}
2878 ~SubCstIntVarIterator() override {}
2879
2880 int64 Value() const override { return cst_ - iterator_->Value(); }
2881
2882 private:
2883 const int64 cst_;
2884 };
2885
2886 SubCstIntVar(Solver* const s, IntVar* v, int64 c);
2887 ~SubCstIntVar() override;
2888
2889 int64 Min() const override;
2890 void SetMin(int64 m) override;
2891 int64 Max() const override;
2892 void SetMax(int64 m) override;
2893 void SetRange(int64 l, int64 u) override;
2894 void SetValue(int64 v) override;
2895 bool Bound() const override;
2896 int64 Value() const override;
2897 void RemoveValue(int64 v) override;
2898 void RemoveInterval(int64 l, int64 u) override;
2899 uint64 Size() const override;
2900 bool Contains(int64 v) const override;
2901 void WhenRange(Demon* d) override;
2902 void WhenBound(Demon* d) override;
2903 void WhenDomain(Demon* d) override;
2904 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2905 return COND_REV_ALLOC(
2906 reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2907 }
2908 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2909 return COND_REV_ALLOC(
2910 reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2911 }
2912 int64 OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2913 int64 OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2914 std::string DebugString() const override;
2915 std::string name() const override;
2916 int VarType() const override { return CST_SUB_VAR; }
2917
2918 void Accept(ModelVisitor* const visitor) const override {
2919 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2920 cst_, var_);
2921 }
2922
2923 IntVar* IsEqual(int64 constant) override {
2924 return var_->IsEqual(cst_ - constant);
2925 }
2926
2927 IntVar* IsDifferent(int64 constant) override {
2928 return var_->IsDifferent(cst_ - constant);
2929 }
2930
2931 IntVar* IsGreaterOrEqual(int64 constant) override {
2932 return var_->IsLessOrEqual(cst_ - constant);
2933 }
2934
2935 IntVar* IsLessOrEqual(int64 constant) override {
2936 return var_->IsGreaterOrEqual(cst_ - constant);
2937 }
2938
2939 IntVar* SubVar() const { return var_; }
2940 int64 Constant() const { return cst_; }
2941
2942 private:
2943 IntVar* const var_;
2944 const int64 cst_;
2945};
2946
2947SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64 c)
2948 : IntVar(s), var_(v), cst_(c) {}
2949
2950SubCstIntVar::~SubCstIntVar() {}
2951
2952int64 SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2953
2954void SubCstIntVar::SetMin(int64 m) { var_->SetMax(CapSub(cst_, m)); }
2955
2956int64 SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2957
2958void SubCstIntVar::SetMax(int64 m) { var_->SetMin(CapSub(cst_, m)); }
2959
2960void SubCstIntVar::SetRange(int64 l, int64 u) {
2961 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2962}
2963
2964void SubCstIntVar::SetValue(int64 v) { var_->SetValue(cst_ - v); }
2965
2966bool SubCstIntVar::Bound() const { return var_->Bound(); }
2967
2968void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2969
2970int64 SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2971
2972void SubCstIntVar::RemoveValue(int64 v) { var_->RemoveValue(cst_ - v); }
2973
2974void SubCstIntVar::RemoveInterval(int64 l, int64 u) {
2975 var_->RemoveInterval(cst_ - u, cst_ - l);
2976}
2977
2978void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2979
2980void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2981
2982uint64 SubCstIntVar::Size() const { return var_->Size(); }
2983
2984bool SubCstIntVar::Contains(int64 v) const { return var_->Contains(cst_ - v); }
2985
2986std::string SubCstIntVar::DebugString() const {
2987 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2988 return absl::StrFormat("Not(%s)", var_->DebugString());
2989 } else {
2990 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
2991 }
2992}
2993
2994std::string SubCstIntVar::name() const {
2995 if (solver()->HasName(this)) {
2997 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2998 return absl::StrFormat("Not(%s)", var_->name());
2999 } else {
3000 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3001 }
3002}
3003
3004// -x variable, optimized case
3005
3006class OppIntVar : public IntVar {
3007 public:
3008 class OppIntVarIterator : public UnaryIterator {
3009 public:
3010 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3011 : UnaryIterator(v, hole, reversible) {}
3012 ~OppIntVarIterator() override {}
3013
3014 int64 Value() const override { return -iterator_->Value(); }
3015 };
3016
3017 OppIntVar(Solver* const s, IntVar* v);
3018 ~OppIntVar() override;
3019
3020 int64 Min() const override;
3021 void SetMin(int64 m) override;
3022 int64 Max() const override;
3023 void SetMax(int64 m) override;
3024 void SetRange(int64 l, int64 u) override;
3025 void SetValue(int64 v) override;
3026 bool Bound() const override;
3027 int64 Value() const override;
3028 void RemoveValue(int64 v) override;
3029 void RemoveInterval(int64 l, int64 u) override;
3030 uint64 Size() const override;
3031 bool Contains(int64 v) const override;
3032 void WhenRange(Demon* d) override;
3033 void WhenBound(Demon* d) override;
3034 void WhenDomain(Demon* d) override;
3035 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3036 return COND_REV_ALLOC(reversible,
3037 new OppIntVarIterator(var_, true, reversible));
3038 }
3039 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3040 return COND_REV_ALLOC(reversible,
3041 new OppIntVarIterator(var_, false, reversible));
3042 }
3043 int64 OldMin() const override { return CapOpp(var_->OldMax()); }
3044 int64 OldMax() const override { return CapOpp(var_->OldMin()); }
3045 std::string DebugString() const override;
3046 int VarType() const override { return OPP_VAR; }
3047
3048 void Accept(ModelVisitor* const visitor) const override {
3049 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3050 var_);
3051 }
3052
3053 IntVar* IsEqual(int64 constant) override { return var_->IsEqual(-constant); }
3054
3055 IntVar* IsDifferent(int64 constant) override {
3056 return var_->IsDifferent(-constant);
3057 }
3058
3059 IntVar* IsGreaterOrEqual(int64 constant) override {
3060 return var_->IsLessOrEqual(-constant);
3061 }
3062
3063 IntVar* IsLessOrEqual(int64 constant) override {
3064 return var_->IsGreaterOrEqual(-constant);
3065 }
3066
3067 IntVar* SubVar() const { return var_; }
3068
3069 private:
3070 IntVar* const var_;
3071};
3072
3073OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3074
3075OppIntVar::~OppIntVar() {}
3076
3077int64 OppIntVar::Min() const { return -var_->Max(); }
3078
3079void OppIntVar::SetMin(int64 m) { var_->SetMax(CapOpp(m)); }
3080
3081int64 OppIntVar::Max() const { return -var_->Min(); }
3082
3083void OppIntVar::SetMax(int64 m) { var_->SetMin(CapOpp(m)); }
3084
3085void OppIntVar::SetRange(int64 l, int64 u) {
3086 var_->SetRange(CapOpp(u), CapOpp(l));
3087}
3088
3089void OppIntVar::SetValue(int64 v) { var_->SetValue(CapOpp(v)); }
3090
3091bool OppIntVar::Bound() const { return var_->Bound(); }
3092
3093void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3094
3095int64 OppIntVar::Value() const { return -var_->Value(); }
3096
3097void OppIntVar::RemoveValue(int64 v) { var_->RemoveValue(-v); }
3098
3099void OppIntVar::RemoveInterval(int64 l, int64 u) {
3100 var_->RemoveInterval(-u, -l);
3101}
3102
3103void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3104
3105void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3106
3107uint64 OppIntVar::Size() const { return var_->Size(); }
3108
3109bool OppIntVar::Contains(int64 v) const { return var_->Contains(-v); }
3110
3111std::string OppIntVar::DebugString() const {
3112 return absl::StrFormat("-(%s)", var_->DebugString());
3113}
3114
3115// ----- Utility functions -----
3116
3117// x * c variable, optimized case
3118
3119class TimesCstIntVar : public IntVar {
3120 public:
3121 TimesCstIntVar(Solver* const s, IntVar* v, int64 c)
3122 : IntVar(s), var_(v), cst_(c) {}
3123 ~TimesCstIntVar() override {}
3124
3125 IntVar* SubVar() const { return var_; }
3126 int64 Constant() const { return cst_; }
3127
3128 void Accept(ModelVisitor* const visitor) const override {
3129 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3130 var_);
3131 }
3132
3133 IntVar* IsEqual(int64 constant) override {
3134 if (constant % cst_ == 0) {
3135 return var_->IsEqual(constant / cst_);
3136 } else {
3137 return solver()->MakeIntConst(0);
3138 }
3139 }
3140
3141 IntVar* IsDifferent(int64 constant) override {
3142 if (constant % cst_ == 0) {
3143 return var_->IsDifferent(constant / cst_);
3144 } else {
3145 return solver()->MakeIntConst(1);
3146 }
3147 }
3148
3149 IntVar* IsGreaterOrEqual(int64 constant) override {
3150 if (cst_ > 0) {
3151 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3152 } else {
3153 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3154 }
3155 }
3156
3157 IntVar* IsLessOrEqual(int64 constant) override {
3158 if (cst_ > 0) {
3159 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3160 } else {
3161 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3162 }
3163 }
3164
3165 std::string DebugString() const override {
3166 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3167 }
3168
3169 int VarType() const override { return VAR_TIMES_CST; }
3170
3171 protected:
3172 IntVar* const var_;
3173 const int64 cst_;
3174};
3175
3176class TimesPosCstIntVar : public TimesCstIntVar {
3177 public:
3178 class TimesPosCstIntVarIterator : public UnaryIterator {
3179 public:
3180 TimesPosCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3181 bool reversible)
3182 : UnaryIterator(v, hole, reversible), cst_(c) {}
3183 ~TimesPosCstIntVarIterator() override {}
3184
3185 int64 Value() const override { return iterator_->Value() * cst_; }
3186
3187 private:
3188 const int64 cst_;
3189 };
3190
3191 TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c);
3192 ~TimesPosCstIntVar() override;
3193
3194 int64 Min() const override;
3195 void SetMin(int64 m) override;
3196 int64 Max() const override;
3197 void SetMax(int64 m) override;
3198 void SetRange(int64 l, int64 u) override;
3199 void SetValue(int64 v) override;
3200 bool Bound() const override;
3201 int64 Value() const override;
3202 void RemoveValue(int64 v) override;
3203 void RemoveInterval(int64 l, int64 u) override;
3204 uint64 Size() const override;
3205 bool Contains(int64 v) const override;
3206 void WhenRange(Demon* d) override;
3207 void WhenBound(Demon* d) override;
3208 void WhenDomain(Demon* d) override;
3209 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3210 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3211 var_, cst_, true, reversible));
3212 }
3213 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3214 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3215 var_, cst_, false, reversible));
3216 }
3217 int64 OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3218 int64 OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3219};
3220
3221// ----- TimesPosCstIntVar -----
3222
3223TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c)
3224 : TimesCstIntVar(s, v, c) {}
3225
3226TimesPosCstIntVar::~TimesPosCstIntVar() {}
3227
3228int64 TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3229
3230void TimesPosCstIntVar::SetMin(int64 m) {
3231 if (m != kint64min) {
3232 var_->SetMin(PosIntDivUp(m, cst_));
3233 }
3234}
3235
3236int64 TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3237
3238void TimesPosCstIntVar::SetMax(int64 m) {
3239 if (m != kint64max) {
3240 var_->SetMax(PosIntDivDown(m, cst_));
3241 }
3242}
3243
3244void TimesPosCstIntVar::SetRange(int64 l, int64 u) {
3245 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3246}
3247
3248void TimesPosCstIntVar::SetValue(int64 v) {
3249 if (v % cst_ != 0) {
3250 solver()->Fail();
3251 }
3252 var_->SetValue(v / cst_);
3253}
3254
3255bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3256
3257void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3258
3259int64 TimesPosCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3260
3261void TimesPosCstIntVar::RemoveValue(int64 v) {
3262 if (v % cst_ == 0) {
3263 var_->RemoveValue(v / cst_);
3264 }
3265}
3266
3267void TimesPosCstIntVar::RemoveInterval(int64 l, int64 u) {
3268 for (int64 v = l; v <= u; ++v) {
3269 RemoveValue(v);
3270 }
3271 // TODO(user) : Improve me
3272}
3273
3274void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3275
3276void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3277
3278uint64 TimesPosCstIntVar::Size() const { return var_->Size(); }
3279
3280bool TimesPosCstIntVar::Contains(int64 v) const {
3281 return (v % cst_ == 0 && var_->Contains(v / cst_));
3282}
3283
3284// b * c variable, optimized case
3285
3286class TimesPosCstBoolVar : public TimesCstIntVar {
3287 public:
3288 class TimesPosCstBoolVarIterator : public UnaryIterator {
3289 public:
3290 // TODO(user) : optimize this.
3291 TimesPosCstBoolVarIterator(const IntVar* const v, int64 c, bool hole,
3292 bool reversible)
3293 : UnaryIterator(v, hole, reversible), cst_(c) {}
3294 ~TimesPosCstBoolVarIterator() override {}
3295
3296 int64 Value() const override { return iterator_->Value() * cst_; }
3297
3298 private:
3299 const int64 cst_;
3300 };
3301
3302 TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c);
3303 ~TimesPosCstBoolVar() override;
3304
3305 int64 Min() const override;
3306 void SetMin(int64 m) override;
3307 int64 Max() const override;
3308 void SetMax(int64 m) override;
3309 void SetRange(int64 l, int64 u) override;
3310 void SetValue(int64 v) override;
3311 bool Bound() const override;
3312 int64 Value() const override;
3313 void RemoveValue(int64 v) override;
3314 void RemoveInterval(int64 l, int64 u) override;
3315 uint64 Size() const override;
3316 bool Contains(int64 v) const override;
3317 void WhenRange(Demon* d) override;
3318 void WhenBound(Demon* d) override;
3319 void WhenDomain(Demon* d) override;
3320 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3321 return COND_REV_ALLOC(reversible, new EmptyIterator());
3322 }
3323 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3324 return COND_REV_ALLOC(
3325 reversible,
3326 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3327 }
3328 int64 OldMin() const override { return 0; }
3329 int64 OldMax() const override { return cst_; }
3330
3331 BooleanVar* boolean_var() const {
3332 return reinterpret_cast<BooleanVar*>(var_);
3333 }
3334};
3335
3336// ----- TimesPosCstBoolVar -----
3337
3338TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c)
3339 : TimesCstIntVar(s, v, c) {}
3340
3341TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3342
3343int64 TimesPosCstBoolVar::Min() const {
3344 return (boolean_var()->RawValue() == 1) * cst_;
3345}
3346
3347void TimesPosCstBoolVar::SetMin(int64 m) {
3348 if (m > cst_) {
3349 solver()->Fail();
3350 } else if (m > 0) {
3351 boolean_var()->SetMin(1);
3352 }
3353}
3354
3355int64 TimesPosCstBoolVar::Max() const {
3356 return (boolean_var()->RawValue() != 0) * cst_;
3357}
3358
3359void TimesPosCstBoolVar::SetMax(int64 m) {
3360 if (m < 0) {
3361 solver()->Fail();
3362 } else if (m < cst_) {
3363 boolean_var()->SetMax(0);
3364 }
3365}
3366
3367void TimesPosCstBoolVar::SetRange(int64 l, int64 u) {
3368 if (u < 0 || l > cst_ || l > u) {
3369 solver()->Fail();
3370 }
3371 if (l > 0) {
3372 boolean_var()->SetMin(1);
3373 } else if (u < cst_) {
3374 boolean_var()->SetMax(0);
3375 }
3376}
3377
3378void TimesPosCstBoolVar::SetValue(int64 v) {
3379 if (v == 0) {
3380 boolean_var()->SetValue(0);
3381 } else if (v == cst_) {
3382 boolean_var()->SetValue(1);
3383 } else {
3384 solver()->Fail();
3385 }
3386}
3387
3388bool TimesPosCstBoolVar::Bound() const {
3389 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3390}
3391
3392void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3393
3395 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3396 << " variable is not bound";
3397 return boolean_var()->RawValue() * cst_;
3398}
3399
3400void TimesPosCstBoolVar::RemoveValue(int64 v) {
3401 if (v == 0) {
3402 boolean_var()->RemoveValue(0);
3403 } else if (v == cst_) {
3404 boolean_var()->RemoveValue(1);
3405 }
3406}
3407
3408void TimesPosCstBoolVar::RemoveInterval(int64 l, int64 u) {
3409 if (l <= 0 && u >= 0) {
3410 boolean_var()->RemoveValue(0);
3411 }
3412 if (l <= cst_ && u >= cst_) {
3413 boolean_var()->RemoveValue(1);
3414 }
3415}
3416
3417void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3418
3419void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3420
3421uint64 TimesPosCstBoolVar::Size() const {
3422 return (1 +
3423 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3424}
3425
3426bool TimesPosCstBoolVar::Contains(int64 v) const {
3427 if (v == 0) {
3428 return boolean_var()->RawValue() != 1;
3429 } else if (v == cst_) {
3430 return boolean_var()->RawValue() != 0;
3431 }
3432 return false;
3433}
3434
3435// TimesNegCstIntVar
3436
3437class TimesNegCstIntVar : public TimesCstIntVar {
3438 public:
3439 class TimesNegCstIntVarIterator : public UnaryIterator {
3440 public:
3441 TimesNegCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3442 bool reversible)
3443 : UnaryIterator(v, hole, reversible), cst_(c) {}
3444 ~TimesNegCstIntVarIterator() override {}
3445
3446 int64 Value() const override { return iterator_->Value() * cst_; }
3447
3448 private:
3449 const int64 cst_;
3450 };
3451
3452 TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c);
3453 ~TimesNegCstIntVar() override;
3454
3455 int64 Min() const override;
3456 void SetMin(int64 m) override;
3457 int64 Max() const override;
3458 void SetMax(int64 m) override;
3459 void SetRange(int64 l, int64 u) override;
3460 void SetValue(int64 v) override;
3461 bool Bound() const override;
3462 int64 Value() const override;
3463 void RemoveValue(int64 v) override;
3464 void RemoveInterval(int64 l, int64 u) override;
3465 uint64 Size() const override;
3466 bool Contains(int64 v) const override;
3467 void WhenRange(Demon* d) override;
3468 void WhenBound(Demon* d) override;
3469 void WhenDomain(Demon* d) override;
3470 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3471 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3472 var_, cst_, true, reversible));
3473 }
3474 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3475 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3476 var_, cst_, false, reversible));
3477 }
3478 int64 OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3479 int64 OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3480};
3481
3482// ----- TimesNegCstIntVar -----
3483
3484TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c)
3485 : TimesCstIntVar(s, v, c) {}
3486
3487TimesNegCstIntVar::~TimesNegCstIntVar() {}
3488
3489int64 TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3490
3491void TimesNegCstIntVar::SetMin(int64 m) {
3492 if (m != kint64min) {
3493 var_->SetMax(PosIntDivDown(-m, -cst_));
3494 }
3495}
3496
3497int64 TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3498
3499void TimesNegCstIntVar::SetMax(int64 m) {
3500 if (m != kint64max) {
3501 var_->SetMin(PosIntDivUp(-m, -cst_));
3502 }
3503}
3504
3505void TimesNegCstIntVar::SetRange(int64 l, int64 u) {
3506 var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3507}
3508
3509void TimesNegCstIntVar::SetValue(int64 v) {
3510 if (v % cst_ != 0) {
3511 solver()->Fail();
3512 }
3513 var_->SetValue(v / cst_);
3514}
3515
3516bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3517
3518void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3519
3520int64 TimesNegCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3521
3522void TimesNegCstIntVar::RemoveValue(int64 v) {
3523 if (v % cst_ == 0) {
3524 var_->RemoveValue(v / cst_);
3525 }
3526}
3527
3528void TimesNegCstIntVar::RemoveInterval(int64 l, int64 u) {
3529 for (int64 v = l; v <= u; ++v) {
3530 RemoveValue(v);
3531 }
3532 // TODO(user) : Improve me
3533}
3534
3535void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3536
3537void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3538
3539uint64 TimesNegCstIntVar::Size() const { return var_->Size(); }
3540
3541bool TimesNegCstIntVar::Contains(int64 v) const {
3542 return (v % cst_ == 0 && var_->Contains(v / cst_));
3543}
3544
3545// ---------- arithmetic expressions ----------
3546
3547// ----- PlusIntExpr -----
3548
3549class PlusIntExpr : public BaseIntExpr {
3550 public:
3551 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3552 : BaseIntExpr(s), left_(l), right_(r) {}
3553
3554 ~PlusIntExpr() override {}
3555
3556 int64 Min() const override { return left_->Min() + right_->Min(); }
3557
3558 void SetMin(int64 m) override {
3559 if (m > left_->Min() + right_->Min()) {
3560 left_->SetMin(m - right_->Max());
3561 right_->SetMin(m - left_->Max());
3562 }
3563 }
3564
3565 void SetRange(int64 l, int64 u) override {
3566 const int64 left_min = left_->Min();
3567 const int64 right_min = right_->Min();
3568 const int64 left_max = left_->Max();
3569 const int64 right_max = right_->Max();
3570 if (l > left_min + right_min) {
3571 left_->SetMin(l - right_max);
3572 right_->SetMin(l - left_max);
3573 }
3574 if (u < left_max + right_max) {
3575 left_->SetMax(u - right_min);
3576 right_->SetMax(u - left_min);
3577 }
3578 }
3579
3580 int64 Max() const override { return left_->Max() + right_->Max(); }
3581
3582 void SetMax(int64 m) override {
3583 if (m < left_->Max() + right_->Max()) {
3584 left_->SetMax(m - right_->Min());
3585 right_->SetMax(m - left_->Min());
3586 }
3587 }
3588
3589 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3590
3591 void Range(int64* const mi, int64* const ma) override {
3592 *mi = left_->Min() + right_->Min();
3593 *ma = left_->Max() + right_->Max();
3594 }
3595
3596 std::string name() const override {
3597 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3598 }
3599
3600 std::string DebugString() const override {
3601 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3602 right_->DebugString());
3603 }
3604
3605 void WhenRange(Demon* d) override {
3606 left_->WhenRange(d);
3607 right_->WhenRange(d);
3608 }
3609
3610 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3611 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3612 if (casted != nullptr) {
3613 ExpandPlusIntExpr(casted->left_, subs);
3614 ExpandPlusIntExpr(casted->right_, subs);
3615 } else {
3616 subs->push_back(expr);
3617 }
3618 }
3619
3620 IntVar* CastToVar() override {
3621 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3622 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3623 std::vector<IntExpr*> sub_exprs;
3624 ExpandPlusIntExpr(left_, &sub_exprs);
3625 ExpandPlusIntExpr(right_, &sub_exprs);
3626 if (sub_exprs.size() >= 3) {
3627 std::vector<IntVar*> sub_vars(sub_exprs.size());
3628 for (int i = 0; i < sub_exprs.size(); ++i) {
3629 sub_vars[i] = sub_exprs[i]->Var();
3630 }
3631 return solver()->MakeSum(sub_vars)->Var();
3632 }
3633 }
3634 return BaseIntExpr::CastToVar();
3635 }
3636
3637 void Accept(ModelVisitor* const visitor) const override {
3638 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3639 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3640 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3641 right_);
3642 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3643 }
3644
3645 private:
3646 IntExpr* const left_;
3647 IntExpr* const right_;
3648};
3649
3650class SafePlusIntExpr : public BaseIntExpr {
3651 public:
3652 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3653 : BaseIntExpr(s), left_(l), right_(r) {}
3654
3655 ~SafePlusIntExpr() override {}
3656
3657 int64 Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3658
3659 void SetMin(int64 m) override {
3660 left_->SetMin(CapSub(m, right_->Max()));
3661 right_->SetMin(CapSub(m, left_->Max()));
3662 }
3663
3664 void SetRange(int64 l, int64 u) override {
3665 const int64 left_min = left_->Min();
3666 const int64 right_min = right_->Min();
3667 const int64 left_max = left_->Max();
3668 const int64 right_max = right_->Max();
3669 if (l > CapAdd(left_min, right_min)) {
3670 left_->SetMin(CapSub(l, right_max));
3671 right_->SetMin(CapSub(l, left_max));
3672 }
3673 if (u < CapAdd(left_max, right_max)) {
3674 left_->SetMax(CapSub(u, right_min));
3675 right_->SetMax(CapSub(u, left_min));
3676 }
3677 }
3678
3679 int64 Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3680
3681 void SetMax(int64 m) override {
3682 left_->SetMax(CapSub(m, right_->Min()));
3683 right_->SetMax(CapSub(m, left_->Min()));
3684 }
3685
3686 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3687
3688 std::string name() const override {
3689 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3690 }
3691
3692 std::string DebugString() const override {
3693 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3694 right_->DebugString());
3695 }
3696
3697 void WhenRange(Demon* d) override {
3698 left_->WhenRange(d);
3699 right_->WhenRange(d);
3700 }
3701
3702 void Accept(ModelVisitor* const visitor) const override {
3703 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3704 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3705 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3706 right_);
3707 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3708 }
3709
3710 private:
3711 IntExpr* const left_;
3712 IntExpr* const right_;
3713};
3714
3715// ----- PlusIntCstExpr -----
3716
3717class PlusIntCstExpr : public BaseIntExpr {
3718 public:
3719 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3720 : BaseIntExpr(s), expr_(e), value_(v) {}
3721 ~PlusIntCstExpr() override {}
3722 int64 Min() const override { return CapAdd(expr_->Min(), value_); }
3723 void SetMin(int64 m) override { expr_->SetMin(CapSub(m, value_)); }
3724 int64 Max() const override { return CapAdd(expr_->Max(), value_); }
3725 void SetMax(int64 m) override { expr_->SetMax(CapSub(m, value_)); }
3726 bool Bound() const override { return (expr_->Bound()); }
3727 std::string name() const override {
3728 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3729 }
3730 std::string DebugString() const override {
3731 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3732 }
3733 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3734 IntVar* CastToVar() override;
3735 void Accept(ModelVisitor* const visitor) const override {
3736 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3737 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3738 expr_);
3739 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3740 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3741 }
3742
3743 private:
3744 IntExpr* const expr_;
3745 const int64 value_;
3746};
3747
3748IntVar* PlusIntCstExpr::CastToVar() {
3749 Solver* const s = solver();
3750 IntVar* const var = expr_->Var();
3751 IntVar* cast = nullptr;
3752 if (AddOverflows(value_, expr_->Max()) ||
3753 AddOverflows(value_, expr_->Min())) {
3754 return BaseIntExpr::CastToVar();
3755 }
3756 switch (var->VarType()) {
3757 case DOMAIN_INT_VAR:
3758 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3759 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3760 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3761 break;
3762 default:
3763 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3764 break;
3765 }
3766 return cast;
3767}
3768
3769// ----- SubIntExpr -----
3770
3771class SubIntExpr : public BaseIntExpr {
3772 public:
3773 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3774 : BaseIntExpr(s), left_(l), right_(r) {}
3775
3776 ~SubIntExpr() override {}
3777
3778 int64 Min() const override { return left_->Min() - right_->Max(); }
3779
3780 void SetMin(int64 m) override {
3781 left_->SetMin(CapAdd(m, right_->Min()));
3782 right_->SetMax(CapSub(left_->Max(), m));
3783 }
3784
3785 int64 Max() const override { return left_->Max() - right_->Min(); }
3786
3787 void SetMax(int64 m) override {
3788 left_->SetMax(CapAdd(m, right_->Max()));
3789 right_->SetMin(CapSub(left_->Min(), m));
3790 }
3791
3792 void Range(int64* mi, int64* ma) override {
3793 *mi = left_->Min() - right_->Max();
3794 *ma = left_->Max() - right_->Min();
3795 }
3796
3797 void SetRange(int64 l, int64 u) override {
3798 const int64 left_min = left_->Min();
3799 const int64 right_min = right_->Min();
3800 const int64 left_max = left_->Max();
3801 const int64 right_max = right_->Max();
3802 if (l > left_min - right_max) {
3803 left_->SetMin(CapAdd(l, right_min));
3804 right_->SetMax(CapSub(left_max, l));
3805 }
3806 if (u < left_max - right_min) {
3807 left_->SetMax(CapAdd(u, right_max));
3808 right_->SetMin(CapSub(left_min, u));
3809 }
3810 }
3811
3812 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3813
3814 std::string name() const override {
3815 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3816 }
3817
3818 std::string DebugString() const override {
3819 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3820 right_->DebugString());
3821 }
3822
3823 void WhenRange(Demon* d) override {
3824 left_->WhenRange(d);
3825 right_->WhenRange(d);
3826 }
3827
3828 void Accept(ModelVisitor* const visitor) const override {
3829 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3830 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3831 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3832 right_);
3833 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3834 }
3835
3836 IntExpr* left() const { return left_; }
3837 IntExpr* right() const { return right_; }
3838
3839 protected:
3840 IntExpr* const left_;
3841 IntExpr* const right_;
3842};
3843
3844class SafeSubIntExpr : public SubIntExpr {
3845 public:
3846 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3847 : SubIntExpr(s, l, r) {}
3848
3849 ~SafeSubIntExpr() override {}
3850
3851 int64 Min() const override { return CapSub(left_->Min(), right_->Max()); }
3852
3853 void SetMin(int64 m) override {
3854 left_->SetMin(CapAdd(m, right_->Min()));
3855 right_->SetMax(CapSub(left_->Max(), m));
3856 }
3857
3858 void SetRange(int64 l, int64 u) override {
3859 const int64 left_min = left_->Min();
3860 const int64 right_min = right_->Min();
3861 const int64 left_max = left_->Max();
3862 const int64 right_max = right_->Max();
3863 if (l > CapSub(left_min, right_max)) {
3864 left_->SetMin(CapAdd(l, right_min));
3865 right_->SetMax(CapSub(left_max, l));
3866 }
3867 if (u < CapSub(left_max, right_min)) {
3868 left_->SetMax(CapAdd(u, right_max));
3869 right_->SetMin(CapSub(left_min, u));
3870 }
3871 }
3872
3873 void Range(int64* mi, int64* ma) override {
3874 *mi = CapSub(left_->Min(), right_->Max());
3875 *ma = CapSub(left_->Max(), right_->Min());
3876 }
3877
3878 int64 Max() const override { return CapSub(left_->Max(), right_->Min()); }
3879
3880 void SetMax(int64 m) override {
3881 left_->SetMax(CapAdd(m, right_->Max()));
3882 right_->SetMin(CapSub(left_->Min(), m));
3883 }
3884};
3885
3886// l - r
3887
3888// ----- SubIntCstExpr -----
3889
3890class SubIntCstExpr : public BaseIntExpr {
3891 public:
3892 SubIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3893 : BaseIntExpr(s), expr_(e), value_(v) {}
3894 ~SubIntCstExpr() override {}
3895 int64 Min() const override { return CapSub(value_, expr_->Max()); }
3896 void SetMin(int64 m) override { expr_->SetMax(CapSub(value_, m)); }
3897 int64 Max() const override { return CapSub(value_, expr_->Min()); }
3898 void SetMax(int64 m) override { expr_->SetMin(CapSub(value_, m)); }
3899 bool Bound() const override { return (expr_->Bound()); }
3900 std::string name() const override {
3901 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3902 }
3903 std::string DebugString() const override {
3904 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3905 }
3906 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3907 IntVar* CastToVar() override;
3908
3909 void Accept(ModelVisitor* const visitor) const override {
3910 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3911 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3912 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3913 expr_);
3914 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3915 }
3916
3917 private:
3918 IntExpr* const expr_;
3919 const int64 value_;
3920};
3921
3922IntVar* SubIntCstExpr::CastToVar() {
3923 if (SubOverflows(value_, expr_->Min()) ||
3924 SubOverflows(value_, expr_->Max())) {
3925 return BaseIntExpr::CastToVar();
3926 }
3927 Solver* const s = solver();
3928 IntVar* const var =
3929 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3930 return var;
3931}
3932
3933// ----- OppIntExpr -----
3934
3935class OppIntExpr : public BaseIntExpr {
3936 public:
3937 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3938 ~OppIntExpr() override {}
3939 int64 Min() const override { return (-expr_->Max()); }
3940 void SetMin(int64 m) override { expr_->SetMax(-m); }
3941 int64 Max() const override { return (-expr_->Min()); }
3942 void SetMax(int64 m) override { expr_->SetMin(-m); }
3943 bool Bound() const override { return (expr_->Bound()); }
3944 std::string name() const override {
3945 return absl::StrFormat("(-%s)", expr_->name());
3946 }
3947 std::string DebugString() const override {
3948 return absl::StrFormat("(-%s)", expr_->DebugString());
3949 }
3950 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3951 IntVar* CastToVar() override;
3952
3953 void Accept(ModelVisitor* const visitor) const override {
3954 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3955 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3956 expr_);
3957 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3958 }
3959
3960 private:
3961 IntExpr* const expr_;
3962};
3963
3964IntVar* OppIntExpr::CastToVar() {
3965 Solver* const s = solver();
3966 IntVar* const var =
3967 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3968 return var;
3969}
3970
3971// ----- TimesIntCstExpr -----
3972
3973class TimesIntCstExpr : public BaseIntExpr {
3974 public:
3975 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3976 : BaseIntExpr(s), expr_(e), value_(v) {}
3977
3978 ~TimesIntCstExpr() override {}
3979
3980 bool Bound() const override { return (expr_->Bound()); }
3981
3982 std::string name() const override {
3983 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
3984 }
3985
3986 std::string DebugString() const override {
3987 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
3988 }
3989
3990 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3991
3992 IntExpr* Expr() const { return expr_; }
3993
3994 int64 Constant() const { return value_; }
3995
3996 void Accept(ModelVisitor* const visitor) const override {
3997 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
3998 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3999 expr_);
4000 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4001 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4002 }
4003
4004 protected:
4005 IntExpr* const expr_;
4006 const int64 value_;
4007};
4008
4009// ----- TimesPosIntCstExpr -----
4010
4011class TimesPosIntCstExpr : public TimesIntCstExpr {
4012 public:
4013 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4014 : TimesIntCstExpr(s, e, v) {
4015 CHECK_GT(v, 0);
4016 }
4017
4018 ~TimesPosIntCstExpr() override {}
4019
4020 int64 Min() const override { return expr_->Min() * value_; }
4021
4022 void SetMin(int64 m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4023
4024 int64 Max() const override { return expr_->Max() * value_; }
4025
4026 void SetMax(int64 m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4027
4028 IntVar* CastToVar() override {
4029 Solver* const s = solver();
4030 IntVar* var = nullptr;
4031 if (expr_->IsVar() &&
4032 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4033 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4034 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4035 } else {
4036 var = s->RegisterIntVar(
4037 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4038 }
4039 return var;
4040 }
4041};
4042
4043// This expressions adds safe arithmetic (w.r.t. overflows) compared
4044// to the previous one.
4045class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4046 public:
4047 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4048 : TimesIntCstExpr(s, e, v) {
4049 CHECK_GT(v, 0);
4050 }
4051
4052 ~SafeTimesPosIntCstExpr() override {}
4053
4054 int64 Min() const override { return CapProd(expr_->Min(), value_); }
4055
4056 void SetMin(int64 m) override {
4057 if (m != kint64min) {
4058 expr_->SetMin(PosIntDivUp(m, value_));
4059 }
4060 }
4061
4062 int64 Max() const override { return CapProd(expr_->Max(), value_); }
4063
4064 void SetMax(int64 m) override {
4065 if (m != kint64max) {
4066 expr_->SetMax(PosIntDivDown(m, value_));
4067 }
4068 }
4069
4070 IntVar* CastToVar() override {
4071 Solver* const s = solver();
4072 IntVar* var = nullptr;
4073 if (expr_->IsVar() &&
4074 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4075 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4076 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4077 } else {
4078 // TODO(user): Check overflows.
4079 var = s->RegisterIntVar(
4080 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4081 }
4082 return var;
4083 }
4084};
4085
4086// ----- TimesIntNegCstExpr -----
4087
4088class TimesIntNegCstExpr : public TimesIntCstExpr {
4089 public:
4090 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64 v)
4091 : TimesIntCstExpr(s, e, v) {
4092 CHECK_LT(v, 0);
4093 }
4094
4095 ~TimesIntNegCstExpr() override {}
4096
4097 int64 Min() const override { return CapProd(expr_->Max(), value_); }
4098
4099 void SetMin(int64 m) override {
4100 if (m != kint64min) {
4101 expr_->SetMax(PosIntDivDown(-m, -value_));
4102 }
4103 }
4104
4105 int64 Max() const override { return CapProd(expr_->Min(), value_); }
4106
4107 void SetMax(int64 m) override {
4108 if (m != kint64max) {
4109 expr_->SetMin(PosIntDivUp(-m, -value_));
4110 }
4111 }
4112
4113 IntVar* CastToVar() override {
4114 Solver* const s = solver();
4115 IntVar* var = nullptr;
4116 var = s->RegisterIntVar(
4117 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4118 return var;
4119 }
4120};
4121
4122// ----- Utilities for product expression -----
4123
4124// Propagates set_min on left * right, left and right >= 0.
4125void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4126 DCHECK_GE(left->Min(), 0);
4127 DCHECK_GE(right->Min(), 0);
4128 const int64 lmax = left->Max();
4129 const int64 rmax = right->Max();
4130 if (m > CapProd(lmax, rmax)) {
4131 left->solver()->Fail();
4132 }
4133 if (m > CapProd(left->Min(), right->Min())) {
4134 // Ok for m == 0 due to left and right being positive
4135 if (0 != rmax) {
4136 left->SetMin(PosIntDivUp(m, rmax));
4137 }
4138 if (0 != lmax) {
4139 right->SetMin(PosIntDivUp(m, lmax));
4140 }
4141 }
4142}
4143
4144// Propagates set_max on left * right, left and right >= 0.
4145void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4146 DCHECK_GE(left->Min(), 0);
4147 DCHECK_GE(right->Min(), 0);
4148 const int64 lmin = left->Min();
4149 const int64 rmin = right->Min();
4150 if (m < CapProd(lmin, rmin)) {
4151 left->solver()->Fail();
4152 }
4153 if (m < CapProd(left->Max(), right->Max())) {
4154 if (0 != lmin) {
4155 right->SetMax(PosIntDivDown(m, lmin));
4156 }
4157 if (0 != rmin) {
4158 left->SetMax(PosIntDivDown(m, rmin));
4159 }
4160 // else do nothing: 0 is supporting any value from other expr.
4161 }
4162}
4163
4164// Propagates set_min on left * right, left >= 0, right across 0.
4165void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4166 DCHECK_GE(left->Min(), 0);
4167 DCHECK_GT(right->Max(), 0);
4168 DCHECK_LT(right->Min(), 0);
4169 const int64 lmax = left->Max();
4170 const int64 rmax = right->Max();
4171 if (m > CapProd(lmax, rmax)) {
4172 left->solver()->Fail();
4173 }
4174 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4175 DCHECK_EQ(0, left->Min());
4176 DCHECK_LE(m, 0);
4177 } else {
4178 if (m > 0) { // We deduce right > 0.
4179 left->SetMin(PosIntDivUp(m, rmax));
4180 right->SetMin(PosIntDivUp(m, lmax));
4181 } else if (m == 0) {
4182 const int64 lmin = left->Min();
4183 if (lmin > 0) {
4184 right->SetMin(0);
4185 }
4186 } else { // m < 0
4187 const int64 lmin = left->Min();
4188 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4189 right->SetMin(-PosIntDivDown(-m, lmin));
4190 }
4191 }
4192 }
4193}
4194
4195// Propagates set_min on left * right, left and right across 0.
4196void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4197 DCHECK_LT(left->Min(), 0);
4198 DCHECK_GT(left->Max(), 0);
4199 DCHECK_GT(right->Max(), 0);
4200 DCHECK_LT(right->Min(), 0);
4201 const int64 lmin = left->Min();
4202 const int64 lmax = left->Max();
4203 const int64 rmin = right->Min();
4204 const int64 rmax = right->Max();
4205 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4206 left->solver()->Fail();
4207 }
4208 if (m > lmin * rmin) { // Must be positive section * positive section.
4209 left->SetMin(PosIntDivUp(m, rmax));
4210 right->SetMin(PosIntDivUp(m, lmax));
4211 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4212 left->SetMax(-PosIntDivUp(m, -rmin));
4213 right->SetMax(-PosIntDivUp(m, -lmin));
4214 }
4215}
4216
4217void TimesSetMin(IntExpr* const left, IntExpr* const right,
4218 IntExpr* const minus_left, IntExpr* const minus_right,
4219 int64 m) {
4220 if (left->Min() >= 0) {
4221 if (right->Min() >= 0) {
4222 SetPosPosMinExpr(left, right, m);
4223 } else if (right->Max() <= 0) {
4224 SetPosPosMaxExpr(left, minus_right, -m);
4225 } else { // right->Min() < 0 && right->Max() > 0
4226 SetPosGenMinExpr(left, right, m);
4227 }
4228 } else if (left->Max() <= 0) {
4229 if (right->Min() >= 0) {
4230 SetPosPosMaxExpr(right, minus_left, -m);
4231 } else if (right->Max() <= 0) {
4232 SetPosPosMinExpr(minus_left, minus_right, m);
4233 } else { // right->Min() < 0 && right->Max() > 0
4234 SetPosGenMinExpr(minus_left, minus_right, m);
4235 }
4236 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4237 SetPosGenMinExpr(right, left, m);
4238 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4239 SetPosGenMinExpr(minus_right, minus_left, m);
4240 } else { // left->Min() < 0 && left->Max() > 0 &&
4241 // right->Min() < 0 && right->Max() > 0
4242 SetGenGenMinExpr(left, right, m);
4243 }
4244}
4245
4246class TimesIntExpr : public BaseIntExpr {
4247 public:
4248 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4249 : BaseIntExpr(s),
4250 left_(l),
4251 right_(r),
4252 minus_left_(s->MakeOpposite(left_)),
4253 minus_right_(s->MakeOpposite(right_)) {}
4254 ~TimesIntExpr() override {}
4255 int64 Min() const override {
4256 const int64 lmin = left_->Min();
4257 const int64 lmax = left_->Max();
4258 const int64 rmin = right_->Min();
4259 const int64 rmax = right_->Max();
4260 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4261 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4262 }
4263 void SetMin(int64 m) override;
4264 int64 Max() const override {
4265 const int64 lmin = left_->Min();
4266 const int64 lmax = left_->Max();
4267 const int64 rmin = right_->Min();
4268 const int64 rmax = right_->Max();
4269 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4270 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4271 }
4272 void SetMax(int64 m) override;
4273 bool Bound() const override;
4274 std::string name() const override {
4275 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4276 }
4277 std::string DebugString() const override {
4278 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4279 right_->DebugString());
4280 }
4281 void WhenRange(Demon* d) override {
4282 left_->WhenRange(d);
4283 right_->WhenRange(d);
4284 }
4285
4286 void Accept(ModelVisitor* const visitor) const override {
4287 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4288 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4289 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4290 right_);
4291 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4292 }
4293
4294 private:
4295 IntExpr* const left_;
4296 IntExpr* const right_;
4297 IntExpr* const minus_left_;
4298 IntExpr* const minus_right_;
4299};
4300
4301void TimesIntExpr::SetMin(int64 m) {
4302 if (m != kint64min) {
4303 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4304 }
4305}
4306
4307void TimesIntExpr::SetMax(int64 m) {
4308 if (m != kint64max) {
4309 TimesSetMin(left_, minus_right_, minus_left_, right_, -m);
4310 }
4311}
4312
4313bool TimesIntExpr::Bound() const {
4314 const bool left_bound = left_->Bound();
4315 const bool right_bound = right_->Bound();
4316 return ((left_bound && left_->Max() == 0) ||
4317 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4318}
4319
4320// ----- TimesPosIntExpr -----
4321
4322class TimesPosIntExpr : public BaseIntExpr {
4323 public:
4324 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4325 : BaseIntExpr(s), left_(l), right_(r) {}
4326 ~TimesPosIntExpr() override {}
4327 int64 Min() const override { return (left_->Min() * right_->Min()); }
4328 void SetMin(int64 m) override;
4329 int64 Max() const override { return (left_->Max() * right_->Max()); }
4330 void SetMax(int64 m) override;
4331 bool Bound() const override;
4332 std::string name() const override {
4333 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4334 }
4335 std::string DebugString() const override {
4336 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4337 right_->DebugString());
4338 }
4339 void WhenRange(Demon* d) override {
4340 left_->WhenRange(d);
4341 right_->WhenRange(d);
4342 }
4343
4344 void Accept(ModelVisitor* const visitor) const override {
4345 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4346 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4347 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4348 right_);
4349 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4350 }
4351
4352 private:
4353 IntExpr* const left_;
4354 IntExpr* const right_;
4355};
4356
4357void TimesPosIntExpr::SetMin(int64 m) { SetPosPosMinExpr(left_, right_, m); }
4358
4359void TimesPosIntExpr::SetMax(int64 m) { SetPosPosMaxExpr(left_, right_, m); }
4360
4361bool TimesPosIntExpr::Bound() const {
4362 return (left_->Max() == 0 || right_->Max() == 0 ||
4363 (left_->Bound() && right_->Bound()));
4364}
4365
4366// ----- SafeTimesPosIntExpr -----
4367
4368class SafeTimesPosIntExpr : public BaseIntExpr {
4369 public:
4370 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4371 : BaseIntExpr(s), left_(l), right_(r) {}
4372 ~SafeTimesPosIntExpr() override {}
4373 int64 Min() const override { return CapProd(left_->Min(), right_->Min()); }
4374 void SetMin(int64 m) override {
4375 if (m != kint64min) {
4376 SetPosPosMinExpr(left_, right_, m);
4377 }
4378 }
4379 int64 Max() const override { return CapProd(left_->Max(), right_->Max()); }
4380 void SetMax(int64 m) override {
4381 if (m != kint64max) {
4382 SetPosPosMaxExpr(left_, right_, m);
4383 }
4384 }
4385 bool Bound() const override {
4386 return (left_->Max() == 0 || right_->Max() == 0 ||
4387 (left_->Bound() && right_->Bound()));
4388 }
4389 std::string name() const override {
4390 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4391 }
4392 std::string DebugString() const override {
4393 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4394 right_->DebugString());
4395 }
4396 void WhenRange(Demon* d) override {
4397 left_->WhenRange(d);
4398 right_->WhenRange(d);
4399 }
4400
4401 void Accept(ModelVisitor* const visitor) const override {
4402 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4403 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4404 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4405 right_);
4406 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4407 }
4408
4409 private:
4410 IntExpr* const left_;
4411 IntExpr* const right_;
4412};
4413
4414// ----- TimesBooleanPosIntExpr -----
4415
4416class TimesBooleanPosIntExpr : public BaseIntExpr {
4417 public:
4418 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4419 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4420 ~TimesBooleanPosIntExpr() override {}
4421 int64 Min() const override {
4422 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4423 }
4424 void SetMin(int64 m) override;
4425 int64 Max() const override {
4426 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4427 }
4428 void SetMax(int64 m) override;
4429 void Range(int64* mi, int64* ma) override;
4430 void SetRange(int64 mi, int64 ma) override;
4431 bool Bound() const override;
4432 std::string name() const override {
4433 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4434 }
4435 std::string DebugString() const override {
4436 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4437 expr_->DebugString());
4438 }
4439 void WhenRange(Demon* d) override {
4440 boolvar_->WhenRange(d);
4441 expr_->WhenRange(d);
4442 }
4443
4444 void Accept(ModelVisitor* const visitor) const override {
4445 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4446 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4447 boolvar_);
4448 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4449 expr_);
4450 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4451 }
4452
4453 private:
4454 BooleanVar* const boolvar_;
4455 IntExpr* const expr_;
4456};
4457
4458void TimesBooleanPosIntExpr::SetMin(int64 m) {
4459 if (m > 0) {
4460 boolvar_->SetValue(1);
4461 expr_->SetMin(m);
4462 }
4463}
4464
4465void TimesBooleanPosIntExpr::SetMax(int64 m) {
4466 if (m < 0) {
4467 solver()->Fail();
4468 }
4469 if (m < expr_->Min()) {
4470 boolvar_->SetValue(0);
4471 }
4472 if (boolvar_->RawValue() == 1) {
4473 expr_->SetMax(m);
4474 }
4475}
4476
4477void TimesBooleanPosIntExpr::Range(int64* mi, int64* ma) {
4478 const int value = boolvar_->RawValue();
4479 if (value == 0) {
4480 *mi = 0;
4481 *ma = 0;
4482 } else if (value == 1) {
4483 expr_->Range(mi, ma);
4484 } else {
4485 *mi = 0;
4486 *ma = expr_->Max();
4487 }
4488}
4489
4490void TimesBooleanPosIntExpr::SetRange(int64 mi, int64 ma) {
4491 if (ma < 0 || mi > ma) {
4492 solver()->Fail();
4493 }
4494 if (mi > 0) {
4495 boolvar_->SetValue(1);
4496 expr_->SetMin(mi);
4497 }
4498 if (ma < expr_->Min()) {
4499 boolvar_->SetValue(0);
4500 }
4501 if (boolvar_->RawValue() == 1) {
4502 expr_->SetMax(ma);
4503 }
4504}
4505
4506bool TimesBooleanPosIntExpr::Bound() const {
4507 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4508 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4509 expr_->Bound()));
4510}
4511
4512// ----- TimesBooleanIntExpr -----
4513
4514class TimesBooleanIntExpr : public BaseIntExpr {
4515 public:
4516 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4517 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4518 ~TimesBooleanIntExpr() override {}
4519 int64 Min() const override {
4520 switch (boolvar_->RawValue()) {
4521 case 0: {
4522 return 0LL;
4523 }
4524 case 1: {
4525 return expr_->Min();
4526 }
4527 default: {
4528 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4529 return std::min(int64{0}, expr_->Min());
4530 }
4531 }
4532 }
4533 void SetMin(int64 m) override;
4534 int64 Max() const override {
4535 switch (boolvar_->RawValue()) {
4536 case 0: {
4537 return 0LL;
4538 }
4539 case 1: {
4540 return expr_->Max();
4541 }
4542 default: {
4543 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4544 return std::max(int64{0}, expr_->Max());
4545 }
4546 }
4547 }
4548 void SetMax(int64 m) override;
4549 void Range(int64* mi, int64* ma) override;
4550 void SetRange(int64 mi, int64 ma) override;
4551 bool Bound() const override;
4552 std::string name() const override {
4553 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4554 }
4555 std::string DebugString() const override {
4556 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4557 expr_->DebugString());
4558 }
4559 void WhenRange(Demon* d) override {
4560 boolvar_->WhenRange(d);
4561 expr_->WhenRange(d);
4562 }
4563
4564 void Accept(ModelVisitor* const visitor) const override {
4565 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4566 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4567 boolvar_);
4568 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4569 expr_);
4570 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4571 }
4572
4573 private:
4574 BooleanVar* const boolvar_;
4575 IntExpr* const expr_;
4576};
4577
4578void TimesBooleanIntExpr::SetMin(int64 m) {
4579 switch (boolvar_->RawValue()) {
4580 case 0: {
4581 if (m > 0) {
4582 solver()->Fail();
4583 }
4584 break;
4585 }
4586 case 1: {
4587 expr_->SetMin(m);
4588 break;
4589 }
4590 default: {
4591 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4592 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4593 boolvar_->SetValue(1);
4594 expr_->SetMin(m);
4595 } else if (m <= 0 && expr_->Max() < m) {
4596 boolvar_->SetValue(0);
4597 }
4598 }
4599 }
4600}
4601
4602void TimesBooleanIntExpr::SetMax(int64 m) {
4603 switch (boolvar_->RawValue()) {
4604 case 0: {
4605 if (m < 0) {
4606 solver()->Fail();
4607 }
4608 break;
4609 }
4610 case 1: {
4611 expr_->SetMax(m);
4612 break;
4613 }
4614 default: {
4615 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4616 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4617 boolvar_->SetValue(1);
4618 expr_->SetMax(m);
4619 } else if (m >= 0 && expr_->Min() > m) {
4620 boolvar_->SetValue(0);
4621 }
4622 }
4623 }
4624}
4625
4626void TimesBooleanIntExpr::Range(int64* mi, int64* ma) {
4627 switch (boolvar_->RawValue()) {
4628 case 0: {
4629 *mi = 0;
4630 *ma = 0;
4631 break;
4632 }
4633 case 1: {
4634 *mi = expr_->Min();
4635 *ma = expr_->Max();
4636 break;
4637 }
4638 default: {
4639 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4640 *mi = std::min(int64{0}, expr_->Min());
4641 *ma = std::max(int64{0}, expr_->Max());
4642 break;
4643 }
4644 }
4645}
4646
4647void TimesBooleanIntExpr::SetRange(int64 mi, int64 ma) {
4648 if (mi > ma) {
4649 solver()->Fail();
4650 }
4651 switch (boolvar_->RawValue()) {
4652 case 0: {
4653 if (mi > 0 || ma < 0) {
4654 solver()->Fail();
4655 }
4656 break;
4657 }
4658 case 1: {
4659 expr_->SetRange(mi, ma);
4660 break;
4661 }
4662 default: {
4663 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4664 if (mi > 0) {
4665 boolvar_->SetValue(1);
4666 expr_->SetMin(mi);
4667 } else if (mi == 0 && expr_->Max() < 0) {
4668 boolvar_->SetValue(0);
4669 }
4670 if (ma < 0) {
4671 boolvar_->SetValue(1);
4672 expr_->SetMax(ma);
4673 } else if (ma == 0 && expr_->Min() > 0) {
4674 boolvar_->SetValue(0);
4675 }
4676 break;
4677 }
4678 }
4679}
4680
4681bool TimesBooleanIntExpr::Bound() const {
4682 return (boolvar_->RawValue() == 0 ||
4683 (expr_->Bound() &&
4684 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4685 expr_->Max() == 0)));
4686}
4687
4688// ----- DivPosIntCstExpr -----
4689
4690class DivPosIntCstExpr : public BaseIntExpr {
4691 public:
4692 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4693 : BaseIntExpr(s), expr_(e), value_(v) {
4694 CHECK_GE(v, 0);
4695 }
4696 ~DivPosIntCstExpr() override {}
4697
4698 int64 Min() const override { return expr_->Min() / value_; }
4699
4700 void SetMin(int64 m) override {
4701 if (m > 0) {
4702 expr_->SetMin(m * value_);
4703 } else {
4704 expr_->SetMin((m - 1) * value_ + 1);
4705 }
4706 }
4707 int64 Max() const override { return expr_->Max() / value_; }
4708
4709 void SetMax(int64 m) override {
4710 if (m >= 0) {
4711 expr_->SetMax((m + 1) * value_ - 1);
4712 } else {
4713 expr_->SetMax(m * value_);
4714 }
4715 }
4716
4717 std::string name() const override {
4718 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4719 }
4720
4721 std::string DebugString() const override {
4722 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4723 }
4724
4725 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4726
4727 void Accept(ModelVisitor* const visitor) const override {
4728 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4729 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4730 expr_);
4731 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4732 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4733 }
4734
4735 private:
4736 IntExpr* const expr_;
4737 const int64 value_;
4738};
4739
4740// DivPosIntExpr
4741
4742class DivPosIntExpr : public BaseIntExpr {
4743 public:
4744 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4745 : BaseIntExpr(s),
4746 num_(num),
4747 denom_(denom),
4748 opp_num_(s->MakeOpposite(num)) {}
4749
4750 ~DivPosIntExpr() override {}
4751
4752 int64 Min() const override {
4753 return num_->Min() >= 0
4754 ? num_->Min() / denom_->Max()
4755 : (denom_->Min() == 0 ? num_->Min()
4756 : num_->Min() / denom_->Min());
4757 }
4758
4759 int64 Max() const override {
4760 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4761 : num_->Max() / denom_->Min())
4762 : num_->Max() / denom_->Max();
4763 }
4764
4765 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4766 num->SetMin(m * denom->Min());
4767 denom->SetMax(num->Max() / m);
4768 }
4769
4770 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4771 num->SetMax((m + 1) * denom->Max() - 1);
4772 denom->SetMin(num->Min() / (m + 1) + 1);
4773 }
4774
4775 void SetMin(int64 m) override {
4776 if (m > 0) {
4777 SetPosMin(num_, denom_, m);
4778 } else {
4779 SetPosMax(opp_num_, denom_, -m);
4780 }
4781 }
4782
4783 void SetMax(int64 m) override {
4784 if (m >= 0) {
4785 SetPosMax(num_, denom_, m);
4786 } else {
4787 SetPosMin(opp_num_, denom_, -m);
4788 }
4789 }
4790
4791 std::string name() const override {
4792 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4793 }
4794 std::string DebugString() const override {
4795 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4796 denom_->DebugString());
4797 }
4798 void WhenRange(Demon* d) override {
4799 num_->WhenRange(d);
4800 denom_->WhenRange(d);
4801 }
4802
4803 void Accept(ModelVisitor* const visitor) const override {
4804 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4805 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4806 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4807 denom_);
4808 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4809 }
4810
4811 private:
4812 IntExpr* const num_;
4813 IntExpr* const denom_;
4814 IntExpr* const opp_num_;
4815};
4816
4817class DivPosPosIntExpr : public BaseIntExpr {
4818 public:
4819 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4820 : BaseIntExpr(s), num_(num), denom_(denom) {}
4821
4822 ~DivPosPosIntExpr() override {}
4823
4824 int64 Min() const override {
4825 if (denom_->Max() == 0) {
4826 solver()->Fail();
4827 }
4828 return num_->Min() / denom_->Max();
4829 }
4830
4831 int64 Max() const override {
4832 if (denom_->Min() == 0) {
4833 return num_->Max();
4834 } else {
4835 return num_->Max() / denom_->Min();
4836 }
4837 }
4838
4839 void SetMin(int64 m) override {
4840 if (m > 0) {
4841 num_->SetMin(m * denom_->Min());
4842 denom_->SetMax(num_->Max() / m);
4843 }
4844 }
4845
4846 void SetMax(int64 m) override {
4847 if (m >= 0) {
4848 num_->SetMax((m + 1) * denom_->Max() - 1);
4849 denom_->SetMin(num_->Min() / (m + 1) + 1);
4850 } else {
4851 solver()->Fail();
4852 }
4853 }
4854
4855 std::string name() const override {
4856 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4857 }
4858
4859 std::string DebugString() const override {
4860 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4861 denom_->DebugString());
4862 }
4863
4864 void WhenRange(Demon* d) override {
4865 num_->WhenRange(d);
4866 denom_->WhenRange(d);
4867 }
4868
4869 void Accept(ModelVisitor* const visitor) const override {
4870 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4871 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4872 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4873 denom_);
4874 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4875 }
4876
4877 private:
4878 IntExpr* const num_;
4879 IntExpr* const denom_;
4880};
4881
4882// DivIntExpr
4883
4884class DivIntExpr : public BaseIntExpr {
4885 public:
4886 DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4887 : BaseIntExpr(s),
4888 num_(num),
4889 denom_(denom),
4890 opp_num_(s->MakeOpposite(num)) {}
4891
4892 ~DivIntExpr() override {}
4893
4894 int64 Min() const override {
4895 const int64 num_min = num_->Min();
4896 const int64 num_max = num_->Max();
4897 const int64 denom_min = denom_->Min();
4898 const int64 denom_max = denom_->Max();
4899
4900 if (denom_min == 0 && denom_max == 0) {
4901 return kint64max; // TODO(user): Check this convention.
4902 }
4903
4904 if (denom_min >= 0) { // Denominator strictly positive.
4905 DCHECK_GT(denom_max, 0);
4906 const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4907 return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4908 } else if (denom_max <= 0) { // Denominator strictly negative.
4909 DCHECK_LT(denom_min, 0);
4910 const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4911 return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4912 } else { // Denominator across 0.
4913 return std::min(num_min, -num_max);
4914 }
4915 }
4916
4917 int64 Max() const override {
4918 const int64 num_min = num_->Min();
4919 const int64 num_max = num_->Max();
4920 const int64 denom_min = denom_->Min();
4921 const int64 denom_max = denom_->Max();
4922
4923 if (denom_min == 0 && denom_max == 0) {
4924 return kint64min; // TODO(user): Check this convention.
4925 }
4926
4927 if (denom_min >= 0) { // Denominator strictly positive.
4928 DCHECK_GT(denom_max, 0);
4929 const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4930 return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4931 } else if (denom_max <= 0) { // Denominator strictly negative.
4932 DCHECK_LT(denom_min, 0);
4933 const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4934 return num_min >= 0 ? num_min / denom_min
4935 : -num_min / -adjusted_denom_max;
4936 } else { // Denominator across 0.
4937 return std::max(num_max, -num_min);
4938 }
4939 }
4940
4941 void AdjustDenominator() {
4942 if (denom_->Min() == 0) {
4943 denom_->SetMin(1);
4944 } else if (denom_->Max() == 0) {
4945 denom_->SetMax(-1);
4946 }
4947 }
4948
4949 // m > 0.
4950 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4951 DCHECK_GT(m, 0);
4952 const int64 num_min = num->Min();
4953 const int64 num_max = num->Max();
4954 const int64 denom_min = denom->Min();
4955 const int64 denom_max = denom->Max();
4956 DCHECK_NE(denom_min, 0);
4957 DCHECK_NE(denom_max, 0);
4958 if (denom_min > 0) { // Denominator strictly positive.
4959 num->SetMin(m * denom_min);
4960 denom->SetMax(num_max / m);
4961 } else if (denom_max < 0) { // Denominator strictly negative.
4962 num->SetMax(m * denom_max);
4963 denom->SetMin(num_min / m);
4964 } else { // Denominator across 0.
4965 if (num_min >= 0) {
4966 num->SetMin(m);
4967 denom->SetRange(1, num_max / m);
4968 } else if (num_max <= 0) {
4969 num->SetMax(-m);
4970 denom->SetRange(num_min / m, -1);
4971 } else {
4972 if (m > -num_min) { // Denominator is forced positive.
4973 num->SetMin(m);
4974 denom->SetRange(1, num_max / m);
4975 } else if (m > num_max) { // Denominator is forced negative.
4976 num->SetMax(-m);
4977 denom->SetRange(num_min / m, -1);
4978 } else {
4979 denom->SetRange(num_min / m, num_max / m);
4980 }
4981 }
4982 }
4983 }
4984
4985 // m >= 0.
4986 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4987 DCHECK_GE(m, 0);
4988 const int64 num_min = num->Min();
4989 const int64 num_max = num->Max();
4990 const int64 denom_min = denom->Min();
4991 const int64 denom_max = denom->Max();
4992 DCHECK_NE(denom_min, 0);
4993 DCHECK_NE(denom_max, 0);
4994 if (denom_min > 0) { // Denominator strictly positive.
4995 num->SetMax((m + 1) * denom_max - 1);
4996 denom->SetMin((num_min / (m + 1)) + 1);
4997 } else if (denom_max < 0) {
4998 num->SetMin((m + 1) * denom_min + 1);
4999 denom->SetMax(num_max / (m + 1) - 1);
5000 } else if (num_min > (m + 1) * denom_max - 1) {
5001 denom->SetMax(-1);
5002 } else if (num_max < (m + 1) * denom_min + 1) {
5003 denom->SetMin(1);
5004 }
5005 }
5006
5007 void SetMin(int64 m) override {
5008 AdjustDenominator();
5009 if (m > 0) {
5010 SetPosMin(num_, denom_, m);
5011 } else {
5012 SetPosMax(opp_num_, denom_, -m);
5013 }
5014 }
5015
5016 void SetMax(int64 m) override {
5017 AdjustDenominator();
5018 if (m >= 0) {
5019 SetPosMax(num_, denom_, m);
5020 } else {
5021 SetPosMin(opp_num_, denom_, -m);
5022 }
5023 }
5024
5025 std::string name() const override {
5026 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5027 }
5028 std::string DebugString() const override {
5029 return absl::StrFormat("(%s div %s)", num_->DebugString(),
5030 denom_->DebugString());
5031 }
5032 void WhenRange(Demon* d) override {
5033 num_->WhenRange(d);
5034 denom_->WhenRange(d);
5035 }
5036
5037 void Accept(ModelVisitor* const visitor) const override {
5038 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5039 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5040 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5041 denom_);
5042 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5043 }
5044
5045 private:
5046 IntExpr* const num_;
5047 IntExpr* const denom_;
5048 IntExpr* const opp_num_;
5049};
5050
5051// ----- IntAbs And IntAbsConstraint ------
5052
5053class IntAbsConstraint : public CastConstraint {
5054 public:
5055 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5056 : CastConstraint(s, target), sub_(sub) {}
5057
5058 ~IntAbsConstraint() override {}
5059
5060 void Post() override {
5061 Demon* const sub_demon = MakeConstraintDemon0(
5062 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5063 sub_->WhenRange(sub_demon);
5064 Demon* const target_demon = MakeConstraintDemon0(
5065 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5066 target_var_->WhenRange(target_demon);
5067 }
5068
5069 void InitialPropagate() override {
5070 PropagateSub();
5071 PropagateTarget();
5072 }
5073
5074 void PropagateSub() {
5075 const int64 smin = sub_->Min();
5076 const int64 smax = sub_->Max();
5077 if (smax <= 0) {
5078 target_var_->SetRange(-smax, -smin);
5079 } else if (smin >= 0) {
5080 target_var_->SetRange(smin, smax);
5081 } else {
5082 target_var_->SetRange(0, std::max(-smin, smax));
5083 }
5084 }
5085
5086 void PropagateTarget() {
5087 const int64 target_max = target_var_->Max();
5088 sub_->SetRange(-target_max, target_max);
5089 const int64 target_min = target_var_->Min();
5090 if (target_min > 0) {
5091 if (sub_->Min() > -target_min) {
5092 sub_->SetMin(target_min);
5093 } else if (sub_->Max() < target_min) {
5094 sub_->SetMax(-target_min);
5095 }
5096 }
5097 }
5098
5099 std::string DebugString() const override {
5100 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5101 target_var_->DebugString());
5102 }
5103
5104 void Accept(ModelVisitor* const visitor) const override {
5105 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5106 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5107 sub_);
5108 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5109 target_var_);
5110 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5111 }
5112
5113 private:
5114 IntVar* const sub_;
5115};
5116
5117class IntAbs : public BaseIntExpr {
5118 public:
5119 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5120
5121 ~IntAbs() override {}
5122
5123 int64 Min() const override {
5124 int64 emin = 0;
5125 int64 emax = 0;
5126 expr_->Range(&emin, &emax);
5127 if (emin >= 0) {
5128 return emin;
5129 }
5130 if (emax <= 0) {
5131 return -emax;
5132 }
5133 return 0;
5134 }
5135
5136 void SetMin(int64 m) override {
5137 if (m > 0) {
5138 int64 emin = 0;
5139 int64 emax = 0;
5140 expr_->Range(&emin, &emax);
5141 if (emin > -m) {
5142 expr_->SetMin(m);
5143 } else if (emax < m) {
5144 expr_->SetMax(-m);
5145 }
5146 }
5147 }
5148
5149 int64 Max() const override {
5150 int64 emin = 0;
5151 int64 emax = 0;
5152 expr_->Range(&emin, &emax);
5153 return std::max(-emin, emax);
5154 }
5155
5156 void SetMax(int64 m) override { expr_->SetRange(-m, m); }
5157
5158 void SetRange(int64 mi, int64 ma) override {
5159 expr_->SetRange(-ma, ma);
5160 if (mi > 0) {
5161 int64 emin = 0;
5162 int64 emax = 0;
5163 expr_->Range(&emin, &emax);
5164 if (emin > -mi) {
5165 expr_->SetMin(mi);
5166 } else if (emax < mi) {
5167 expr_->SetMax(-mi);
5168 }
5169 }
5170 }
5171
5172 void Range(int64* mi, int64* ma) override {
5173 int64 emin = 0;
5174 int64 emax = 0;
5175 expr_->Range(&emin, &emax);
5176 if (emin >= 0) {
5177 *mi = emin;
5178 *ma = emax;
5179 } else if (emax <= 0) {
5180 *mi = -emax;
5181 *ma = -emin;
5182 } else {
5183 *mi = 0;
5184 *ma = std::max(-emin, emax);
5185 }
5186 }
5187
5188 bool Bound() const override { return expr_->Bound(); }
5189
5190 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5191
5192 std::string name() const override {
5193 return absl::StrFormat("IntAbs(%s)", expr_->name());
5194 }
5195
5196 std::string DebugString() const override {
5197 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5198 }
5199
5200 void Accept(ModelVisitor* const visitor) const override {
5201 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5202 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5203 expr_);
5204 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5205 }
5206
5207 IntVar* CastToVar() override {
5208 int64 min_value = 0;
5209 int64 max_value = 0;
5210 Range(&min_value, &max_value);
5211 Solver* const s = solver();
5212 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5213 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5214 CastConstraint* const ct =
5215 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5216 s->AddCastConstraint(ct, target, this);
5217 return target;
5218 }
5219
5220 private:
5221 IntExpr* const expr_;
5222};
5223
5224// ----- Square -----
5225
5226// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5227class IntSquare : public BaseIntExpr {
5228 public:
5229 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5230 ~IntSquare() override {}
5231
5232 int64 Min() const override {
5233 const int64 emin = expr_->Min();
5234 if (emin >= 0) {
5235 return emin >= kint32max ? kint64max : emin * emin;
5236 }
5237 const int64 emax = expr_->Max();
5238 if (emax < 0) {
5239 return emax <= -kint32max ? kint64max : emax * emax;
5240 }
5241 return 0LL;
5242 }
5243 void SetMin(int64 m) override {
5244 if (m <= 0) {
5245 return;
5246 }
5247 // TODO(user): What happens if m is kint64max?
5248 const int64 emin = expr_->Min();
5249 const int64 emax = expr_->Max();
5250 const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5251 if (emin >= 0) {
5252 expr_->SetMin(root);
5253 } else if (emax <= 0) {
5254 expr_->SetMax(-root);
5255 } else if (expr_->IsVar()) {
5256 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5257 }
5258 }
5259 int64 Max() const override {
5260 const int64 emax = expr_->Max();
5261 const int64 emin = expr_->Min();
5262 if (emax >= kint32max || emin <= -kint32max) {
5263 return kint64max;
5264 }
5265 return std::max(emin * emin, emax * emax);
5266 }
5267 void SetMax(int64 m) override {
5268 if (m < 0) {
5269 solver()->Fail();
5270 }
5271 if (m == kint64max) {
5272 return;
5273 }
5274 const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5275 expr_->SetRange(-root, root);
5276 }
5277 bool Bound() const override { return expr_->Bound(); }
5278 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5279 std::string name() const override {
5280 return absl::StrFormat("IntSquare(%s)", expr_->name());
5281 }
5282 std::string DebugString() const override {
5283 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5284 }
5285
5286 void Accept(ModelVisitor* const visitor) const override {
5287 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5288 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5289 expr_);
5290 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5291 }
5292
5293 IntExpr* expr() const { return expr_; }
5294
5295 protected:
5296 IntExpr* const expr_;
5297};
5298
5299class PosIntSquare : public IntSquare {
5300 public:
5301 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5302 ~PosIntSquare() override {}
5303
5304 int64 Min() const override {
5305 const int64 emin = expr_->Min();
5306 return emin >= kint32max ? kint64max : emin * emin;
5307 }
5308 void SetMin(int64 m) override {
5309 if (m <= 0) {
5310 return;
5311 }
5312 const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5313 expr_->SetMin(root);
5314 }
5315 int64 Max() const override {
5316 const int64 emax = expr_->Max();
5317 return emax >= kint32max ? kint64max : emax * emax;
5318 }
5319 void SetMax(int64 m) override {
5320 if (m < 0) {
5321 solver()->Fail();
5322 }
5323 if (m == kint64max) {
5324 return;
5325 }
5326 const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5327 expr_->SetMax(root);
5328 }
5329};
5330
5331// ----- EvenPower -----
5332
5333int64 IntPower(int64 value, int64 power) {
5334 int64 result = value;
5335 // TODO(user): Speed that up.
5336 for (int i = 1; i < power; ++i) {
5337 result *= value;
5338 }
5339 return result;
5340}
5341
5342int64 OverflowLimit(int64 power) {
5343 return static_cast<int64>(
5344 floor(exp(log(static_cast<double>(kint64max)) / power)));
5345}
5346
5347class BasePower : public BaseIntExpr {
5348 public:
5349 BasePower(Solver* const s, IntExpr* const e, int64 n)
5350 : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5351 CHECK_GT(n, 0);
5352 }
5353
5354 ~BasePower() override {}
5355
5356 bool Bound() const override { return expr_->Bound(); }
5357
5358 IntExpr* expr() const { return expr_; }
5359
5360 int64 exponant() const { return pow_; }
5361
5362 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5363
5364 std::string name() const override {
5365 return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5366 }
5367
5368 std::string DebugString() const override {
5369 return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5370 }
5371
5372 void Accept(ModelVisitor* const visitor) const override {
5373 visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5374 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5375 expr_);
5376 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5377 visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5378 }
5379
5380 protected:
5381 int64 Pown(int64 value) const {
5382 if (value >= limit_) {
5383 return kint64max;
5384 }
5385 if (value <= -limit_) {
5386 if (pow_ % 2 == 0) {
5387 return kint64max;
5388 } else {
5389 return kint64min;
5390 }
5391 }
5392 return IntPower(value, pow_);
5393 }
5394
5395 int64 SqrnDown(int64 value) const {
5396 if (value == kint64min) {
5397 return kint64min;
5398 }
5399 if (value == kint64max) {
5400 return kint64max;
5401 }
5402 int64 res = 0;
5403 const double d_value = static_cast<double>(value);
5404 if (value >= 0) {
5405 const double sq = exp(log(d_value) / pow_);
5406 res = static_cast<int64>(floor(sq));
5407 } else {
5408 CHECK_EQ(1, pow_ % 2);
5409 const double sq = exp(log(-d_value) / pow_);
5410 res = -static_cast<int64>(ceil(sq));
5411 }
5412 const int64 pow_res = Pown(res + 1);
5413 if (pow_res <= value) {
5414 return res + 1;
5415 } else {
5416 return res;
5417 }
5418 }
5419
5420 int64 SqrnUp(int64 value) const {
5421 if (value == kint64min) {
5422 return kint64min;
5423 }
5424 if (value == kint64max) {
5425 return kint64max;
5426 }
5427 int64 res = 0;
5428 const double d_value = static_cast<double>(value);
5429 if (value >= 0) {
5430 const double sq = exp(log(d_value) / pow_);
5431 res = static_cast<int64>(ceil(sq));
5432 } else {
5433 CHECK_EQ(1, pow_ % 2);
5434 const double sq = exp(log(-d_value) / pow_);
5435 res = -static_cast<int64>(floor(sq));
5436 }
5437 const int64 pow_res = Pown(res - 1);
5438 if (pow_res >= value) {
5439 return res - 1;
5440 } else {
5441 return res;
5442 }
5443 }
5444
5445 IntExpr* const expr_;
5448};
5449
5450class IntEvenPower : public BasePower {
5451 public:
5452 IntEvenPower(Solver* const s, IntExpr* const e, int64 n)
5453 : BasePower(s, e, n) {
5454 CHECK_EQ(0, n % 2);
5455 }
5456
5457 ~IntEvenPower() override {}
5458
5459 int64 Min() const override {
5460 int64 emin = 0;
5461 int64 emax = 0;
5462 expr_->Range(&emin, &emax);
5463 if (emin >= 0) {
5464 return Pown(emin);
5465 }
5466 if (emax < 0) {
5467 return Pown(emax);
5468 }
5469 return 0LL;
5470 }
5471 void SetMin(int64 m) override {
5472 if (m <= 0) {
5473 return;
5474 }
5475 int64 emin = 0;
5476 int64 emax = 0;
5477 expr_->Range(&emin, &emax);
5478 const int64 root = SqrnUp(m);
5479 if (emin > -root) {
5480 expr_->SetMin(root);
5481 } else if (emax < root) {
5482 expr_->SetMax(-root);
5483 } else if (expr_->IsVar()) {
5484 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5485 }
5486 }
5487
5488 int64 Max() const override {
5489 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5490 }
5491
5492 void SetMax(int64 m) override {
5493 if (m < 0) {
5494 solver()->Fail();
5495 }
5496 if (m == kint64max) {
5497 return;
5498 }
5499 const int64 root = SqrnDown(m);
5500 expr_->SetRange(-root, root);
5501 }
5502};
5503
5504class PosIntEvenPower : public BasePower {
5505 public:
5506 PosIntEvenPower(Solver* const s, IntExpr* const e, int64 pow)
5507 : BasePower(s, e, pow) {
5508 CHECK_EQ(0, pow % 2);
5509 }
5510
5511 ~PosIntEvenPower() override {}
5512
5513 int64 Min() const override { return Pown(expr_->Min()); }
5514
5515 void SetMin(int64 m) override {
5516 if (m <= 0) {
5517 return;
5518 }
5519 expr_->SetMin(SqrnUp(m));
5520 }
5521 int64 Max() const override { return Pown(expr_->Max()); }
5522
5523 void SetMax(int64 m) override {
5524 if (m < 0) {
5525 solver()->Fail();
5526 }
5527 if (m == kint64max) {
5528 return;
5529 }
5530 expr_->SetMax(SqrnDown(m));
5531 }
5532};
5533
5534class IntOddPower : public BasePower {
5535 public:
5536 IntOddPower(Solver* const s, IntExpr* const e, int64 n) : BasePower(s, e, n) {
5537 CHECK_EQ(1, n % 2);
5538 }
5539
5540 ~IntOddPower() override {}
5541
5542 int64 Min() const override { return Pown(expr_->Min()); }
5543
5544 void SetMin(int64 m) override { expr_->SetMin(SqrnUp(m)); }
5545
5546 int64 Max() const override { return Pown(expr_->Max()); }
5547
5548 void SetMax(int64 m) override { expr_->SetMax(SqrnDown(m)); }
5549};
5550
5551// ----- Min(expr, expr) -----
5552
5553class MinIntExpr : public BaseIntExpr {
5554 public:
5555 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5556 : BaseIntExpr(s), left_(l), right_(r) {}
5557 ~MinIntExpr() override {}
5558 int64 Min() const override {
5559 const int64 lmin = left_->Min();
5560 const int64 rmin = right_->Min();
5561 return std::min(lmin, rmin);
5562 }
5563 void SetMin(int64 m) override {
5564 left_->SetMin(m);
5565 right_->SetMin(m);
5566 }
5567 int64 Max() const override {
5568 const int64 lmax = left_->Max();
5569 const int64 rmax = right_->Max();
5570 return std::min(lmax, rmax);
5571 }
5572 void SetMax(int64 m) override {
5573 if (left_->Min() > m) {
5574 right_->SetMax(m);
5575 }
5576 if (right_->Min() > m) {
5577 left_->SetMax(m);
5578 }
5579 }
5580 std::string name() const override {
5581 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5582 }
5583 std::string DebugString() const override {
5584 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5585 right_->DebugString());
5586 }
5587 void WhenRange(Demon* d) override {
5588 left_->WhenRange(d);
5589 right_->WhenRange(d);
5590 }
5591
5592 void Accept(ModelVisitor* const visitor) const override {
5593 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5594 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5595 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5596 right_);
5597 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5598 }
5599
5600 private:
5601 IntExpr* const left_;
5602 IntExpr* const right_;
5603};
5604
5605// ----- Min(expr, constant) -----
5606
5607class MinCstIntExpr : public BaseIntExpr {
5608 public:
5609 MinCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5610 : BaseIntExpr(s), expr_(e), value_(v) {}
5611
5612 ~MinCstIntExpr() override {}
5613
5614 int64 Min() const override { return std::min(expr_->Min(), value_); }
5615
5616 void SetMin(int64 m) override {
5617 if (m > value_) {
5618 solver()->Fail();
5619 }
5620 expr_->SetMin(m);
5621 }
5622
5623 int64 Max() const override { return std::min(expr_->Max(), value_); }
5624
5625 void SetMax(int64 m) override {
5626 if (value_ > m) {
5627 expr_->SetMax(m);
5628 }
5629 }
5630
5631 bool Bound() const override {
5632 return (expr_->Bound() || expr_->Min() >= value_);
5633 }
5634
5635 std::string name() const override {
5636 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5637 }
5638
5639 std::string DebugString() const override {
5640 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5641 value_);
5642 }
5643
5644 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5645
5646 void Accept(ModelVisitor* const visitor) const override {
5647 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5648 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5649 expr_);
5650 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5651 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5652 }
5653
5654 private:
5655 IntExpr* const expr_;
5656 const int64 value_;
5657};
5658
5659// ----- Max(expr, expr) -----
5660
5661class MaxIntExpr : public BaseIntExpr {
5662 public:
5663 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5664 : BaseIntExpr(s), left_(l), right_(r) {}
5665
5666 ~MaxIntExpr() override {}
5667
5668 int64 Min() const override { return std::max(left_->Min(), right_->Min()); }
5669
5670 void SetMin(int64 m) override {
5671 if (left_->Max() < m) {
5672 right_->SetMin(m);
5673 } else {
5674 if (right_->Max() < m) {
5675 left_->SetMin(m);
5676 }
5677 }
5678 }
5679
5680 int64 Max() const override { return std::max(left_->Max(), right_->Max()); }
5681
5682 void SetMax(int64 m) override {
5683 left_->SetMax(m);
5684 right_->SetMax(m);
5685 }
5686
5687 std::string name() const override {
5688 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5689 }
5690
5691 std::string DebugString() const override {
5692 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5693 right_->DebugString());
5694 }
5695
5696 void WhenRange(Demon* d) override {
5697 left_->WhenRange(d);
5698 right_->WhenRange(d);
5699 }
5700
5701 void Accept(ModelVisitor* const visitor) const override {
5702 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5703 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5704 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5705 right_);
5706 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5707 }
5708
5709 private:
5710 IntExpr* const left_;
5711 IntExpr* const right_;
5712};
5713
5714// ----- Max(expr, constant) -----
5715
5716class MaxCstIntExpr : public BaseIntExpr {
5717 public:
5718 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5719 : BaseIntExpr(s), expr_(e), value_(v) {}
5720
5721 ~MaxCstIntExpr() override {}
5722
5723 int64 Min() const override { return std::max(expr_->Min(), value_); }
5724
5725 void SetMin(int64 m) override {
5726 if (value_ < m) {
5727 expr_->SetMin(m);
5728 }
5729 }
5730
5731 int64 Max() const override { return std::max(expr_->Max(), value_); }
5732
5733 void SetMax(int64 m) override {
5734 if (m < value_) {
5735 solver()->Fail();
5736 }
5737 expr_->SetMax(m);
5738 }
5739
5740 bool Bound() const override {
5741 return (expr_->Bound() || expr_->Max() <= value_);
5742 }
5743
5744 std::string name() const override {
5745 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5746 }
5747
5748 std::string DebugString() const override {
5749 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5750 value_);
5751 }
5752
5753 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5754
5755 void Accept(ModelVisitor* const visitor) const override {
5756 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5757 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5758 expr_);
5759 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5760 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5761 }
5762
5763 private:
5764 IntExpr* const expr_;
5765 const int64 value_;
5766};
5767
5768// ----- Convex Piecewise -----
5769
5770// This class is a very simple convex piecewise linear function. The
5771// argument of the function is the expression. Between early_date and
5772// late_date, the value of the function is 0. Before early date, it
5773// is affine and the cost is early_cost * (early_date - x). After
5774// late_date, the cost is late_cost * (x - late_date).
5775
5776class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5777 public:
5778 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64 ec,
5779 int64 ed, int64 ld, int64 lc)
5780 : BaseIntExpr(s),
5781 expr_(e),
5782 early_cost_(ec),
5783 early_date_(ec == 0 ? kint64min : ed),
5784 late_date_(lc == 0 ? kint64max : ld),
5785 late_cost_(lc) {
5786 DCHECK_GE(ec, int64{0});
5787 DCHECK_GE(lc, int64{0});
5788 DCHECK_GE(ld, ed);
5789
5790 // If the penalty is 0, we can push the "confort zone or zone
5791 // of no cost towards infinity.
5792 }
5793
5794 ~SimpleConvexPiecewiseExpr() override {}
5795
5796 int64 Min() const override {
5797 const int64 vmin = expr_->Min();
5798 const int64 vmax = expr_->Max();
5799 if (vmin >= late_date_) {
5800 return (vmin - late_date_) * late_cost_;
5801 } else if (vmax <= early_date_) {
5802 return (early_date_ - vmax) * early_cost_;
5803 } else {
5804 return 0LL;
5805 }
5806 }
5807
5808 void SetMin(int64 m) override {
5809 if (m <= 0) {
5810 return;
5811 }
5812 int64 vmin = 0;
5813 int64 vmax = 0;
5814 expr_->Range(&vmin, &vmax);
5815
5816 const int64 rb =
5817 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5818 const int64 lb =
5819 (early_cost_ == 0 ? vmin
5820 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5821
5822 if (expr_->IsVar()) {
5823 expr_->Var()->RemoveInterval(lb, rb);
5824 }
5825 }
5826
5827 int64 Max() const override {
5828 const int64 vmin = expr_->Min();
5829 const int64 vmax = expr_->Max();
5830 const int64 mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5831 const int64 ml =
5832 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5833 return std::max(mr, ml);
5834 }
5835
5836 void SetMax(int64 m) override {
5837 if (m < 0) {
5838 solver()->Fail();
5839 }
5840 if (late_cost_ != 0LL) {
5841 const int64 rb = late_date_ + PosIntDivDown(m, late_cost_);
5842 if (early_cost_ != 0LL) {
5843 const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5844 expr_->SetRange(lb, rb);
5845 } else {
5846 expr_->SetMax(rb);
5847 }
5848 } else {
5849 if (early_cost_ != 0LL) {
5850 const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5851 expr_->SetMin(lb);
5852 }
5853 }
5854 }
5855
5856 std::string name() const override {
5857 return absl::StrFormat(
5858 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5859 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5860 }
5861
5862 std::string DebugString() const override {
5863 return absl::StrFormat(
5864 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5865 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5866 }
5867
5868 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5869
5870 void Accept(ModelVisitor* const visitor) const override {
5871 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5872 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5873 expr_);
5874 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5875 early_cost_);
5876 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5877 early_date_);
5878 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5879 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5880 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5881 }
5882
5883 private:
5884 IntExpr* const expr_;
5885 const int64 early_cost_;
5886 const int64 early_date_;
5887 const int64 late_date_;
5888 const int64 late_cost_;
5889};
5890
5891// ----- Semi Continuous -----
5892
5893class SemiContinuousExpr : public BaseIntExpr {
5894 public:
5895 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64 fixed_charge,
5896 int64 step)
5897 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5898 DCHECK_GE(fixed_charge, int64{0});
5899 DCHECK_GT(step, int64{0});
5900 }
5901
5902 ~SemiContinuousExpr() override {}
5903
5904 int64 Value(int64 x) const {
5905 if (x <= 0) {
5906 return 0;
5907 } else {
5908 return CapAdd(fixed_charge_, CapProd(x, step_));
5909 }
5910 }
5911
5912 int64 Min() const override { return Value(expr_->Min()); }
5913
5914 void SetMin(int64 m) override {
5915 if (m >= CapAdd(fixed_charge_, step_)) {
5916 const int64 y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5917 expr_->SetMin(y);
5918 } else if (m > 0) {
5919 expr_->SetMin(1);
5920 }
5921 }
5922
5923 int64 Max() const override { return Value(expr_->Max()); }
5924
5925 void SetMax(int64 m) override {
5926 if (m < 0) {
5927 solver()->Fail();
5928 }
5929 if (m == kint64max) {
5930 return;
5931 }
5932 if (m < CapAdd(fixed_charge_, step_)) {
5933 expr_->SetMax(0);
5934 } else {
5935 const int64 y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5936 expr_->SetMax(y);
5937 }
5938 }
5939
5940 std::string name() const override {
5941 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5942 expr_->name(), fixed_charge_, step_);
5943 }
5944
5945 std::string DebugString() const override {
5946 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5947 expr_->DebugString(), fixed_charge_, step_);
5948 }
5949
5950 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5951
5952 void Accept(ModelVisitor* const visitor) const override {
5953 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5954 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5955 expr_);
5956 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5957 fixed_charge_);
5958 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5959 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5960 }
5961
5962 private:
5963 IntExpr* const expr_;
5964 const int64 fixed_charge_;
5965 const int64 step_;
5966};
5967
5968class SemiContinuousStepOneExpr : public BaseIntExpr {
5969 public:
5970 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
5971 int64 fixed_charge)
5972 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
5973 DCHECK_GE(fixed_charge, int64{0});
5974 }
5975
5976 ~SemiContinuousStepOneExpr() override {}
5977
5978 int64 Value(int64 x) const {
5979 if (x <= 0) {
5980 return 0;
5981 } else {
5982 return fixed_charge_ + x;
5983 }
5984 }
5985
5986 int64 Min() const override { return Value(expr_->Min()); }
5987
5988 void SetMin(int64 m) override {
5989 if (m >= fixed_charge_ + 1) {
5990 expr_->SetMin(m - fixed_charge_);
5991 } else if (m > 0) {
5992 expr_->SetMin(1);
5993 }
5994 }
5995
5996 int64 Max() const override { return Value(expr_->Max()); }
5997
5998 void SetMax(int64 m) override {
5999 if (m < 0) {
6000 solver()->Fail();
6001 }
6002 if (m < fixed_charge_ + 1) {
6003 expr_->SetMax(0);
6004 } else {
6005 expr_->SetMax(m - fixed_charge_);
6006 }
6007 }
6008
6009 std::string name() const override {
6010 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6011 expr_->name(), fixed_charge_);
6012 }
6013
6014 std::string DebugString() const override {
6015 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6016 expr_->DebugString(), fixed_charge_);
6017 }
6018
6019 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6020
6021 void Accept(ModelVisitor* const visitor) const override {
6022 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6023 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6024 expr_);
6025 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6026 fixed_charge_);
6027 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6028 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6029 }
6030
6031 private:
6032 IntExpr* const expr_;
6033 const int64 fixed_charge_;
6034};
6035
6036class SemiContinuousStepZeroExpr : public BaseIntExpr {
6037 public:
6038 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6039 int64 fixed_charge)
6040 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6041 DCHECK_GT(fixed_charge, int64{0});
6042 }
6043
6044 ~SemiContinuousStepZeroExpr() override {}
6045
6046 int64 Value(int64 x) const {
6047 if (x <= 0) {
6048 return 0;
6049 } else {
6050 return fixed_charge_;
6051 }
6052 }
6053
6054 int64 Min() const override { return Value(expr_->Min()); }
6055
6056 void SetMin(int64 m) override {
6057 if (m >= fixed_charge_) {
6058 solver()->Fail();
6059 } else if (m > 0) {
6060 expr_->SetMin(1);
6061 }
6062 }
6063
6064 int64 Max() const override { return Value(expr_->Max()); }
6065
6066 void SetMax(int64 m) override {
6067 if (m < 0) {
6068 solver()->Fail();
6069 }
6070 if (m < fixed_charge_) {
6071 expr_->SetMax(0);
6072 }
6073 }
6074
6075 std::string name() const override {
6076 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6077 expr_->name(), fixed_charge_);
6078 }
6079
6080 std::string DebugString() const override {
6081 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6082 expr_->DebugString(), fixed_charge_);
6083 }
6084
6085 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6086
6087 void Accept(ModelVisitor* const visitor) const override {
6088 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6089 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6090 expr_);
6091 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6092 fixed_charge_);
6093 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6094 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6095 }
6096
6097 private:
6098 IntExpr* const expr_;
6099 const int64 fixed_charge_;
6100};
6101
6102// This constraints links an expression and the variable it is casted into
6103class LinkExprAndVar : public CastConstraint {
6104 public:
6105 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6106 : CastConstraint(s, var), expr_(expr) {}
6107
6108 ~LinkExprAndVar() override {}
6109
6110 void Post() override {
6111 Solver* const s = solver();
6112 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6113 expr_->WhenRange(d);
6114 target_var_->WhenRange(d);
6115 }
6116
6117 void InitialPropagate() override {
6118 expr_->SetRange(target_var_->Min(), target_var_->Max());
6119 int64 l, u;
6120 expr_->Range(&l, &u);
6121 target_var_->SetRange(l, u);
6122 }
6123
6124 std::string DebugString() const override {
6125 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6126 target_var_->DebugString());
6127 }
6128
6129 void Accept(ModelVisitor* const visitor) const override {
6130 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6131 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6132 expr_);
6133 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6134 target_var_);
6135 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6136 }
6137
6138 private:
6139 IntExpr* const expr_;
6140};
6141
6142// ----- Conditional Expression -----
6143
6144class ExprWithEscapeValue : public BaseIntExpr {
6145 public:
6146 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6147 int64 unperformed_value)
6148 : BaseIntExpr(s),
6149 condition_(c),
6150 expression_(e),
6151 unperformed_value_(unperformed_value) {}
6152
6153 ~ExprWithEscapeValue() override {}
6154
6155 int64 Min() const override {
6156 if (condition_->Min() == 1) {
6157 return expression_->Min();
6158 } else if (condition_->Max() == 1) {
6159 return std::min(unperformed_value_, expression_->Min());
6160 } else {
6161 return unperformed_value_;
6162 }
6163 }
6164
6165 void SetMin(int64 m) override {
6166 if (m > unperformed_value_) {
6167 condition_->SetValue(1);
6168 expression_->SetMin(m);
6169 } else if (condition_->Min() == 1) {
6170 expression_->SetMin(m);
6171 } else if (m > expression_->Max()) {
6172 condition_->SetValue(0);
6173 }
6174 }
6175
6176 int64 Max() const override {
6177 if (condition_->Min() == 1) {
6178 return expression_->Max();
6179 } else if (condition_->Max() == 1) {
6180 return std::max(unperformed_value_, expression_->Max());
6181 } else {
6182 return unperformed_value_;
6183 }
6184 }
6185
6186 void SetMax(int64 m) override {
6187 if (m < unperformed_value_) {
6188 condition_->SetValue(1);
6189 expression_->SetMax(m);
6190 } else if (condition_->Min() == 1) {
6191 expression_->SetMax(m);
6192 } else if (m < expression_->Min()) {
6193 condition_->SetValue(0);
6194 }
6195 }
6196
6197 void SetRange(int64 mi, int64 ma) override {
6198 if (ma < unperformed_value_ || mi > unperformed_value_) {
6199 condition_->SetValue(1);
6200 expression_->SetRange(mi, ma);
6201 } else if (condition_->Min() == 1) {
6202 expression_->SetRange(mi, ma);
6203 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6204 condition_->SetValue(0);
6205 }
6206 }
6207
6208 void SetValue(int64 v) override {
6209 if (v != unperformed_value_) {
6210 condition_->SetValue(1);
6211 expression_->SetValue(v);
6212 } else if (condition_->Min() == 1) {
6213 expression_->SetValue(v);
6214 } else if (v < expression_->Min() || v > expression_->Max()) {
6215 condition_->SetValue(0);
6216 }
6217 }
6218
6219 bool Bound() const override {
6220 return condition_->Max() == 0 || expression_->Bound();
6221 }
6222
6223 void WhenRange(Demon* d) override {
6224 expression_->WhenRange(d);
6225 condition_->WhenBound(d);
6226 }
6227
6228 std::string DebugString() const override {
6229 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6230 condition_->DebugString(),
6231 expression_->DebugString(), unperformed_value_);
6232 }
6233
6234 void Accept(ModelVisitor* const visitor) const override {
6235 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6236 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6237 condition_);
6238 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6239 expression_);
6240 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6241 unperformed_value_);
6242 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6243 }
6244
6245 private:
6246 IntVar* const condition_;
6247 IntExpr* const expression_;
6248 const int64 unperformed_value_;
6249 DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6250};
6251
6252// ----- This is a specialized case when the variable exact type is known -----
6253class LinkExprAndDomainIntVar : public CastConstraint {
6254 public:
6255 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6256 DomainIntVar* const var)
6257 : CastConstraint(s, var),
6258 expr_(expr),
6259 cached_min_(kint64min),
6260 cached_max_(kint64max),
6261 fail_stamp_(uint64_t{0}) {}
6262
6263 ~LinkExprAndDomainIntVar() override {}
6264
6265 DomainIntVar* var() const {
6266 return reinterpret_cast<DomainIntVar*>(target_var_);
6267 }
6268
6269 void Post() override {
6270 Solver* const s = solver();
6271 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6272 expr_->WhenRange(d);
6273 Demon* const target_var_demon = MakeConstraintDemon0(
6274 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6275 target_var_->WhenRange(target_var_demon);
6276 }
6277
6278 void InitialPropagate() override {
6279 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6280 expr_->Range(&cached_min_, &cached_max_);
6281 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6282 }
6283
6284 void Propagate() {
6285 if (var()->min_.Value() > cached_min_ ||
6286 var()->max_.Value() < cached_max_ ||
6287 solver()->fail_stamp() != fail_stamp_) {
6288 InitialPropagate();
6289 fail_stamp_ = solver()->fail_stamp();
6290 }
6291 }
6292
6293 std::string DebugString() const override {
6294 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6295 target_var_->DebugString());
6296 }
6297
6298 void Accept(ModelVisitor* const visitor) const override {
6299 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6300 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6301 expr_);
6302 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6303 target_var_);
6304 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6305 }
6306
6307 private:
6308 IntExpr* const expr_;
6309 int64 cached_min_;
6310 int64 cached_max_;
6311 uint64 fail_stamp_;
6312};
6313} // namespace
6314
6315// ----- Misc -----
6316
6317IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6318 return COND_REV_ALLOC(reversible, new EmptyIterator());
6319}
6320IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6321 return COND_REV_ALLOC(reversible, new RangeIterator(this));
6322}
6323
6324// ----- API -----
6325
6327 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6328 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6329 dvar->CleanInProcess();
6330}
6331
6332Constraint* SetIsEqual(IntVar* const var, const std::vector<int64>& values,
6333 const std::vector<IntVar*>& vars) {
6334 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6335 CHECK(dvar != nullptr);
6336 return dvar->SetIsEqual(values, vars);
6337}
6338
6340 const std::vector<int64>& values,
6341 const std::vector<IntVar*>& vars) {
6342 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6343 CHECK(dvar != nullptr);
6344 return dvar->SetIsGreaterOrEqual(values, vars);
6345}
6346
6348 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6349 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6350 boolean_var->RestoreValue();
6351}
6352
6353// ----- API -----
6354
6355IntVar* Solver::MakeIntVar(int64 min, int64 max, const std::string& name) {
6356 if (min == max) {
6357 return MakeIntConst(min, name);
6358 }
6359 if (min == 0 && max == 1) {
6360 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6361 } else if (CapSub(max, min) == 1) {
6362 const std::string inner_name = "inner_" + name;
6363 return RegisterIntVar(
6364 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6365 ->VarWithName(name));
6366 } else {
6367 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6368 }
6369}
6370
6371IntVar* Solver::MakeIntVar(int64 min, int64 max) {
6372 return MakeIntVar(min, max, "");
6373}
6374
6375IntVar* Solver::MakeBoolVar(const std::string& name) {
6376 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6377}
6378
6379IntVar* Solver::MakeBoolVar() {
6380 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6381}
6382
6383IntVar* Solver::MakeIntVar(const std::vector<int64>& values,
6384 const std::string& name) {
6385 DCHECK(!values.empty());
6386 // Fast-track the case where we have a single value.
6387 if (values.size() == 1) return MakeIntConst(values[0], name);
6388 // Sort and remove duplicates.
6389 std::vector<int64> unique_sorted_values = values;
6390 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6391 // Case when we have a single value, after clean-up.
6392 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6393 // Case when the values are a dense interval of integers.
6394 if (unique_sorted_values.size() ==
6395 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6396 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6397 name);
6398 }
6399 // Compute the GCD: if it's not 1, we can express the variable's domain as
6400 // the product of the GCD and of a domain with smaller values.
6401 int64 gcd = 0;
6402 for (const int64 v : unique_sorted_values) {
6403 if (gcd == 0) {
6404 gcd = std::abs(v);
6405 } else {
6406 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6407 }
6408 if (gcd == 1) {
6409 // If it's 1, though, we can't do anything special, so we
6410 // immediately return a new DomainIntVar.
6411 return RegisterIntVar(
6412 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6413 }
6414 }
6415 DCHECK_GT(gcd, 1);
6416 for (int64& v : unique_sorted_values) {
6417 DCHECK_EQ(0, v % gcd);
6418 v /= gcd;
6419 }
6420 const std::string new_name = name.empty() ? "" : "inner_" + name;
6421 // Catch the case where the divided values are a dense set of integers.
6422 IntVar* inner_intvar = nullptr;
6423 if (unique_sorted_values.size() ==
6424 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6425 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6426 unique_sorted_values.back(), new_name);
6427 } else {
6428 inner_intvar = RegisterIntVar(
6429 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6430 }
6431 return MakeProd(inner_intvar, gcd)->Var();
6432}
6433
6434IntVar* Solver::MakeIntVar(const std::vector<int64>& values) {
6435 return MakeIntVar(values, "");
6436}
6437
6438IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6439 const std::string& name) {
6440 return MakeIntVar(ToInt64Vector(values), name);
6441}
6442
6443IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6444 return MakeIntVar(values, "");
6445}
6446
6447IntVar* Solver::MakeIntConst(int64 val, const std::string& name) {
6448 // If IntConst is going to be named after its creation,
6449 // cp_share_int_consts should be set to false otherwise names can potentially
6450 // be overwritten.
6451 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6452 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6453 return cached_constants_[val - MIN_CACHED_INT_CONST];
6454 }
6455 return RevAlloc(new IntConst(this, val, name));
6456}
6457
6458IntVar* Solver::MakeIntConst(int64 val) { return MakeIntConst(val, ""); }
6459
6460// ----- Int Var and associated methods -----
6461
namespace {
// Returns `prefix` followed by the decimal representation of `index`.
// `max_index` is kept for signature compatibility; it is not used (a
// zero-padded formatting based on it was disabled).
std::string IndexedName(const std::string& prefix, int index,
                        int /*max_index*/) {
  return prefix + std::to_string(index);
}
}  // namespace
6478
6479void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6480 const std::string& name,
6481 std::vector<IntVar*>* vars) {
6482 for (int i = 0; i < var_count; ++i) {
6483 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6484 }
6485}
6486
6487void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6488 std::vector<IntVar*>* vars) {
6489 for (int i = 0; i < var_count; ++i) {
6490 vars->push_back(MakeIntVar(vmin, vmax));
6491 }
6492}
6493
6494IntVar** Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6495 const std::string& name) {
6496 IntVar** vars = new IntVar*[var_count];
6497 for (int i = 0; i < var_count; ++i) {
6498 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6499 }
6500 return vars;
6501}
6502
6503void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6504 std::vector<IntVar*>* vars) {
6505 for (int i = 0; i < var_count; ++i) {
6506 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6507 }
6508}
6509
6510void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6511 for (int i = 0; i < var_count; ++i) {
6512 vars->push_back(MakeBoolVar());
6513 }
6514}
6515
6516IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6517 IntVar** vars = new IntVar*[var_count];
6518 for (int i = 0; i < var_count; ++i) {
6519 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6520 }
6521 return vars;
6522}
6523
6524void Solver::InitCachedIntConstants() {
6525 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6526 cached_constants_[i - MIN_CACHED_INT_CONST] =
6527 RevAlloc(new IntConst(this, i, "")); // note the empty name
6528 }
6529}
6530
6531IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6532 CHECK_EQ(this, left->solver());
6533 CHECK_EQ(this, right->solver());
6534 if (right->Bound()) {
6535 return MakeSum(left, right->Min());
6536 }
6537 if (left->Bound()) {
6538 return MakeSum(right, left->Min());
6539 }
6540 if (left == right) {
6541 return MakeProd(left, 2);
6542 }
6543 IntExpr* cache = model_cache_->FindExprExprExpression(
6544 left, right, ModelCache::EXPR_EXPR_SUM);
6545 if (cache == nullptr) {
6546 cache = model_cache_->FindExprExprExpression(right, left,
6547 ModelCache::EXPR_EXPR_SUM);
6548 }
6549 if (cache != nullptr) {
6550 return cache;
6551 } else {
6552 IntExpr* const result =
6553 AddOverflows(left->Max(), right->Max()) ||
6554 AddOverflows(left->Min(), right->Min())
6555 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6556 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6557 model_cache_->InsertExprExprExpression(result, left, right,
6558 ModelCache::EXPR_EXPR_SUM);
6559 return result;
6560 }
6561}
6562
6563IntExpr* Solver::MakeSum(IntExpr* const expr, int64 value) {
6564 CHECK_EQ(this, expr->solver());
6565 if (expr->Bound()) {
6566 return MakeIntConst(expr->Min() + value);
6567 }
6568 if (value == 0) {
6569 return expr;
6570 }
6571 IntExpr* result = Cache()->FindExprConstantExpression(
6572 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6573 if (result == nullptr) {
6574 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6575 !AddOverflows(value, expr->Min())) {
6576 IntVar* const var = expr->Var();
6577 switch (var->VarType()) {
6578 case DOMAIN_INT_VAR: {
6579 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6580 this, reinterpret_cast<DomainIntVar*>(var), value)));
6581 break;
6582 }
6583 case CONST_VAR: {
6584 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6585 break;
6586 }
6587 case VAR_ADD_CST: {
6588 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6589 IntVar* const sub_var = add_var->SubVar();
6590 const int64 new_constant = value + add_var->Constant();
6591 if (new_constant == 0) {
6592 result = sub_var;
6593 } else {
6594 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6595 DomainIntVar* const dvar =
6596 reinterpret_cast<DomainIntVar*>(sub_var);
6597 result = RegisterIntExpr(
6598 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6599 } else {
6600 result = RegisterIntExpr(
6601 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6602 }
6603 }
6604 break;
6605 }
6606 case CST_SUB_VAR: {
6607 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6608 IntVar* const sub_var = add_var->SubVar();
6609 const int64 new_constant = value + add_var->Constant();
6610 result = RegisterIntExpr(
6611 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6612 break;
6613 }
6614 case OPP_VAR: {
6615 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6616 IntVar* const sub_var = add_var->SubVar();
6617 result =
6618 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6619 break;
6620 }
6621 default:
6622 result =
6623 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6624 }
6625 } else {
6626 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6627 }
6628 Cache()->InsertExprConstantExpression(result, expr, value,
6629 ModelCache::EXPR_CONSTANT_SUM);
6630 }
6631 return result;
6632}
6633
6634IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6635 CHECK_EQ(this, left->solver());
6636 CHECK_EQ(this, right->solver());
6637 if (left->Bound()) {
6638 return MakeDifference(left->Min(), right);
6639 }
6640 if (right->Bound()) {
6641 return MakeSum(left, -right->Min());
6642 }
6643 IntExpr* sub_left = nullptr;
6644 IntExpr* sub_right = nullptr;
6645 int64 left_coef = 1;
6646 int64 right_coef = 1;
6647 if (IsProduct(left, &sub_left, &left_coef) &&
6648 IsProduct(right, &sub_right, &right_coef)) {
6649 const int64 abs_gcd =
6650 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6651 if (abs_gcd != 0 && abs_gcd != 1) {
6652 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6653 MakeProd(sub_right, right_coef / abs_gcd)),
6654 abs_gcd);
6655 }
6656 }
6657
6658 IntExpr* result = Cache()->FindExprExprExpression(
6659 left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6660 if (result == nullptr) {
6661 if (!SubOverflows(left->Min(), right->Max()) &&
6662 !SubOverflows(left->Max(), right->Min())) {
6663 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6664 } else {
6665 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6666 }
6667 Cache()->InsertExprExprExpression(result, left, right,
6668 ModelCache::EXPR_EXPR_DIFFERENCE);
6669 }
6670 return result;
6671}
6672
6673// warning: this is 'value - expr'.
6674IntExpr* Solver::MakeDifference(int64 value, IntExpr* const expr) {
6675 CHECK_EQ(this, expr->solver());
6676 if (expr->Bound()) {
6677 return MakeIntConst(value - expr->Min());
6678 }
6679 if (value == 0) {
6680 return MakeOpposite(expr);
6681 }
6682 IntExpr* result = Cache()->FindExprConstantExpression(
6683 expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6684 if (result == nullptr) {
6685 if (expr->IsVar() && expr->Min() != kint64min &&
6686 !SubOverflows(value, expr->Min()) &&
6687 !SubOverflows(value, expr->Max())) {
6688 IntVar* const var = expr->Var();
6689 switch (var->VarType()) {
6690 case VAR_ADD_CST: {
6691 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6692 IntVar* const sub_var = add_var->SubVar();
6693 const int64 new_constant = value - add_var->Constant();
6694 if (new_constant == 0) {
6695 result = sub_var;
6696 } else {
6697 result = RegisterIntExpr(
6698 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6699 }
6700 break;
6701 }
6702 case CST_SUB_VAR: {
6703 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6704 IntVar* const sub_var = add_var->SubVar();
6705 const int64 new_constant = value - add_var->Constant();
6706 result = MakeSum(sub_var, new_constant);
6707 break;
6708 }
6709 case OPP_VAR: {
6710 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6711 IntVar* const sub_var = add_var->SubVar();
6712 result = MakeSum(sub_var, value);
6713 break;
6714 }
6715 default:
6716 result =
6717 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6718 }
6719 } else {
6720 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6721 }
6722 Cache()->InsertExprConstantExpression(result, expr, value,
6723 ModelCache::EXPR_CONSTANT_DIFFERENCE);
6724 }
6725 return result;
6726}
6727
6728IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6729 CHECK_EQ(this, expr->solver());
6730 if (expr->Bound()) {
6731 return MakeIntConst(-expr->Min());
6732 }
6733 IntExpr* result =
6734 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6735 if (result == nullptr) {
6736 if (expr->IsVar()) {
6737 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6738 } else {
6739 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6740 }
6741 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6742 }
6743 return result;
6744}
6745
6746IntExpr* Solver::MakeProd(IntExpr* const expr, int64 value) {
6747 CHECK_EQ(this, expr->solver());
6748 IntExpr* result = Cache()->FindExprConstantExpression(
6749 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6750 if (result != nullptr) {
6751 return result;
6752 } else {
6753 IntExpr* m_expr = nullptr;
6754 int64 coefficient = 1;
6755 if (IsProduct(expr, &m_expr, &coefficient)) {
6756 coefficient *= value;
6757 } else {
6758 m_expr = expr;
6760 }
6761 if (m_expr->Bound()) {
6762 return MakeIntConst(coefficient * m_expr->Min());
6763 } else if (coefficient == 1) {
6764 return m_expr;
6765 } else if (coefficient == -1) {
6766 return MakeOpposite(m_expr);
6767 } else if (coefficient > 0) {
6768 if (m_expr->Max() > kint64max / coefficient ||
6769 m_expr->Min() < kint64min / coefficient) {
6770 result = RegisterIntExpr(
6771 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6772 } else {
6773 result = RegisterIntExpr(
6774 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6775 }
6776 } else if (coefficient == 0) {
6777 result = MakeIntConst(0);
6778 } else { // coefficient < 0.
6779 result = RegisterIntExpr(
6780 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6781 }
6782 if (m_expr->IsVar() &&
6783 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6784 result = result->Var();
6785 }
6786 Cache()->InsertExprConstantExpression(result, expr, value,
6787 ModelCache::EXPR_CONSTANT_PROD);
6788 return result;
6789 }
6790}
6791
6792namespace {
6793void ExtractPower(IntExpr** const expr, int64* const exponant) {
6794 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6795 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6796 *expr = power->expr();
6797 *exponant = power->exponant();
6798 }
6799 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6800 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6801 *expr = power->expr();
6802 *exponant = 2;
6803 }
6804 if ((*expr)->IsVar()) {
6805 IntVar* const var = (*expr)->Var();
6806 IntExpr* const sub = var->solver()->CastExpression(var);
6807 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6808 BasePower* const power = dynamic_cast<BasePower*>(sub);
6809 *expr = power->expr();
6810 *exponant = power->exponant();
6811 }
6812 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6813 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6814 *expr = power->expr();
6815 *exponant = 2;
6816 }
6817 }
6818}
6819
6820void ExtractProduct(IntExpr** const expr, int64* const coefficient,
6821 bool* modified) {
6822 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6823 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6824 *coefficient *= left_prod->Constant();
6825 *expr = left_prod->SubVar();
6826 *modified = true;
6827 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6828 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6829 *coefficient *= left_prod->Constant();
6830 *expr = left_prod->Expr();
6831 *modified = true;
6832 }
6833}
6834} // namespace
6835
6836IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6837 if (left->Bound()) {
6838 return MakeProd(right, left->Min());
6839 }
6840
6841 if (right->Bound()) {
6842 return MakeProd(left, right->Min());
6843 }
6844
6845 // ----- Discover squares and powers -----
6846
6847 IntExpr* m_left = left;
6848 IntExpr* m_right = right;
6849 int64 left_exponant = 1;
6850 int64 right_exponant = 1;
6851 ExtractPower(&m_left, &left_exponant);
6852 ExtractPower(&m_right, &right_exponant);
6853
6854 if (m_left == m_right) {
6855 return MakePower(m_left, left_exponant + right_exponant);
6856 }
6857
6858 // ----- Discover nested products -----
6859
6860 m_left = left;
6861 m_right = right;
6862 int64 coefficient = 1;
6863 bool modified = false;
6864
6865 ExtractProduct(&m_left, &coefficient, &modified);
6866 ExtractProduct(&m_right, &coefficient, &modified);
6867 if (modified) {
6868 return MakeProd(MakeProd(m_left, m_right), coefficient);
6869 }
6870
6871 // ----- Standard build -----
6872
6873 CHECK_EQ(this, left->solver());
6874 CHECK_EQ(this, right->solver());
6875 IntExpr* result = model_cache_->FindExprExprExpression(
6876 left, right, ModelCache::EXPR_EXPR_PROD);
6877 if (result == nullptr) {
6878 result = model_cache_->FindExprExprExpression(right, left,
6879 ModelCache::EXPR_EXPR_PROD);
6880 }
6881 if (result != nullptr) {
6882 return result;
6883 }
6884 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6885 if (right->Min() >= 0) {
6886 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6887 this, reinterpret_cast<BooleanVar*>(left), right)));
6888 } else {
6889 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6890 this, reinterpret_cast<BooleanVar*>(left), right)));
6891 }
6892 } else if (right->IsVar() &&
6893 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6894 if (left->Min() >= 0) {
6895 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6896 this, reinterpret_cast<BooleanVar*>(right), left)));
6897 } else {
6898 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6899 this, reinterpret_cast<BooleanVar*>(right), left)));
6900 }
6901 } else if (left->Min() >= 0 && right->Min() >= 0) {
6902 if (CapProd(left->Max(), right->Max()) ==
6903 kint64max) { // Potential overflow.
6904 result =
6905 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6906 } else {
6907 result =
6908 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6909 }
6910 } else {
6911 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6912 }
6913 model_cache_->InsertExprExprExpression(result, left, right,
6914 ModelCache::EXPR_EXPR_PROD);
6915 return result;
6916}
6917
6918IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6919 CHECK(numerator != nullptr);
6920 CHECK(denominator != nullptr);
6921 if (denominator->Bound()) {
6922 return MakeDiv(numerator, denominator->Min());
6923 }
6924 IntExpr* result = model_cache_->FindExprExprExpression(
6925 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6926 if (result != nullptr) {
6927 return result;
6928 }
6929
6930 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6931 AddConstraint(MakeNonEquality(denominator, 0));
6932 }
6933
6934 if (denominator->Min() >= 0) {
6935 if (numerator->Min() >= 0) {
6936 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6937 } else {
6938 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6939 }
6940 } else if (denominator->Max() <= 0) {
6941 if (numerator->Max() <= 0) {
6942 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6943 MakeOpposite(denominator)));
6944 } else {
6945 result = MakeOpposite(RevAlloc(
6946 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6947 }
6948 } else {
6949 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6950 }
6951 model_cache_->InsertExprExprExpression(result, numerator, denominator,
6952 ModelCache::EXPR_EXPR_DIV);
6953 return result;
6954}
6955
6956IntExpr* Solver::MakeDiv(IntExpr* const expr, int64 value) {
6957 CHECK(expr != nullptr);
6958 CHECK_EQ(this, expr->solver());
6959 if (expr->Bound()) {
6960 return MakeIntConst(expr->Min() / value);
6961 } else if (value == 1) {
6962 return expr;
6963 } else if (value == -1) {
6964 return MakeOpposite(expr);
6965 } else if (value > 0) {
6966 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
6967 } else if (value == 0) {
6968 LOG(FATAL) << "Cannot divide by 0";
6969 return nullptr;
6970 } else {
6971 return RegisterIntExpr(
6972 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
6973 // TODO(user) : implement special case.
6974 }
6975}
6976
6977Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
6978 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
6979 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
6980 }
6981 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
6982}
6983
6984IntExpr* Solver::MakeAbs(IntExpr* const e) {
6985 CHECK_EQ(this, e->solver());
6986 if (e->Min() >= 0) {
6987 return e;
6988 } else if (e->Max() <= 0) {
6989 return MakeOpposite(e);
6990 }
6991 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
6992 if (result == nullptr) {
6993 int64 coefficient = 1;
6994 IntExpr* expr = nullptr;
6995 if (IsProduct(e, &expr, &coefficient)) {
6996 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
6997 } else {
6998 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
6999 }
7000 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7001 }
7002 return result;
7003}
7004
7005IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7006 CHECK_EQ(this, expr->solver());
7007 if (expr->Bound()) {
7008 const int64 v = expr->Min();
7009 return MakeIntConst(v * v);
7010 }
7011 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7012 if (result == nullptr) {
7013 if (expr->Min() >= 0) {
7014 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7015 } else {
7016 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7017 }
7018 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7019 }
7020 return result;
7021}
7022
7023IntExpr* Solver::MakePower(IntExpr* const expr, int64 n) {
7024 CHECK_EQ(this, expr->solver());
7025 CHECK_GE(n, 0);
7026 if (expr->Bound()) {
7027 const int64 v = expr->Min();
7028 if (v >= OverflowLimit(n)) { // Overflow.
7029 return MakeIntConst(kint64max);
7030 }
7031 return MakeIntConst(IntPower(v, n));
7032 }
7033 switch (n) {
7034 case 0:
7035 return MakeIntConst(1);
7036 case 1:
7037 return expr;
7038 case 2:
7039 return MakeSquare(expr);
7040 default: {
7041 IntExpr* result = nullptr;
7042 if (n % 2 == 0) { // even.
7043 if (expr->Min() >= 0) {
7044 result =
7045 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7046 } else {
7047 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7048 }
7049 } else {
7050 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7051 }
7052 return result;
7053 }
7054 }
7055}
7056
7057IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7058 CHECK_EQ(this, left->solver());
7059 CHECK_EQ(this, right->solver());
7060 if (left->Bound()) {
7061 return MakeMin(right, left->Min());
7062 }
7063 if (right->Bound()) {
7064 return MakeMin(left, right->Min());
7065 }
7066 if (left->Min() >= right->Max()) {
7067 return right;
7068 }
7069 if (right->Min() >= left->Max()) {
7070 return left;
7071 }
7072 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7073}
7074
7075IntExpr* Solver::MakeMin(IntExpr* const expr, int64 value) {
7076 CHECK_EQ(this, expr->solver());
7077 if (value <= expr->Min()) {
7078 return MakeIntConst(value);
7079 }
7080 if (expr->Bound()) {
7081 return MakeIntConst(std::min(expr->Min(), value));
7082 }
7083 if (expr->Max() <= value) {
7084 return expr;
7085 }
7086 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7087}
7088
7089IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7090 return MakeMin(expr, static_cast<int64>(value));
7091}
7092
7093IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7094 CHECK_EQ(this, left->solver());
7095 CHECK_EQ(this, right->solver());
7096 if (left->Bound()) {
7097 return MakeMax(right, left->Min());
7098 }
7099 if (right->Bound()) {
7100 return MakeMax(left, right->Min());
7101 }
7102 if (left->Min() >= right->Max()) {
7103 return left;
7104 }
7105 if (right->Min() >= left->Max()) {
7106 return right;
7107 }
7108 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7109}
7110
7111IntExpr* Solver::MakeMax(IntExpr* const expr, int64 value) {
7112 CHECK_EQ(this, expr->solver());
7113 if (expr->Bound()) {
7114 return MakeIntConst(std::max(expr->Min(), value));
7115 }
7116 if (value <= expr->Min()) {
7117 return expr;
7118 }
7119 if (expr->Max() <= value) {
7120 return MakeIntConst(value);
7121 }
7122 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7123}
7124
7125IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7126 return MakeMax(expr, static_cast<int64>(value));
7127}
7128
7129IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64 early_cost,
7130 int64 early_date, int64 late_date,
7131 int64 late_cost) {
7132 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7133 this, expr, early_cost, early_date, late_date, late_cost)));
7134}
7135
7136IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr, int64 fixed_charge,
7137 int64 step) {
7138 if (step == 0) {
7139 if (fixed_charge == 0) {
7140 return MakeIntConst(int64{0});
7141 } else {
7142 return RegisterIntExpr(
7143 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7144 }
7145 } else if (step == 1) {
7146 return RegisterIntExpr(
7147 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7148 } else {
7149 return RegisterIntExpr(
7150 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7151 }
7152 // TODO(user) : benchmark with virtualization of
7153 // PosIntDivDown and PosIntDivUp - or function pointers.
7154}
7155
7156// ----- Piecewise Linear -----
7157
7159 public:
7161 const PiecewiseLinearFunction& f)
7162 : BaseIntExpr(solver), expr_(expr), f_(f) {}
7164 int64 Min() const override {
7165 return f_.GetMinimum(expr_->Min(), expr_->Max());
7166 }
7167 void SetMin(int64 m) override {
7168 const auto& range =
7169 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7170 expr_->SetRange(range.first, range.second);
7171 }
7172
7173 int64 Max() const override {
7174 return f_.GetMaximum(expr_->Min(), expr_->Max());
7175 }
7176
7177 void SetMax(int64 m) override {
7178 const auto& range =
7179 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7180 expr_->SetRange(range.first, range.second);
7181 }
7182
7183 void SetRange(int64 l, int64 u) override {
7184 const auto& range =
7185 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7186 expr_->SetRange(range.first, range.second);
7187 }
7188 std::string name() const override {
7189 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7190 f_.DebugString());
7191 }
7192
7193 std::string DebugString() const override {
7194 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7195 f_.DebugString());
7196 }
7197
7198 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7199
7200 void Accept(ModelVisitor* const visitor) const override {
7201 // TODO(user): Implement visitor.
7202 }
7203
7204 private:
7205 IntExpr* const expr_;
7206 const PiecewiseLinearFunction f_;
7207};
7208
7209IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7210 const PiecewiseLinearFunction& f) {
7211 return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7212}
7213
7214// ----- Conditional Expression -----
7215
7216IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7217 IntExpr* const expr,
7218 int64 unperformed_value) {
7219 if (condition->Min() == 1) {
7220 return expr;
7221 } else if (condition->Max() == 0) {
7222 return MakeIntConst(unperformed_value);
7223 } else {
7224 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7225 condition, expr, unperformed_value,
7226 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7227 if (cache == nullptr) {
7228 cache = RevAlloc(
7229 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7230 Cache()->InsertExprExprConstantExpression(
7231 cache, condition, expr, unperformed_value,
7232 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7233 }
7234 return cache;
7235 }
7236}
7237
7238// ----- Modulo -----
7239
7240IntExpr* Solver::MakeModulo(IntExpr* const x, int64 mod) {
7241 IntVar* const result =
7242 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7243 if (mod >= 0) {
7244 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7245 } else {
7246 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7247 }
7248 return result;
7249}
7250
7251IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7252 if (mod->Bound()) {
7253 return MakeModulo(x, mod->Min());
7254 }
7255 IntVar* const result =
7256 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7257 AddConstraint(MakeLess(result, MakeAbs(mod)));
7258 AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7259 return result;
7260}
7261
7262// --------- IntVar ---------
7263
7264int IntVar::VarType() const { return UNSPECIFIED; }
7265
7266void IntVar::RemoveValues(const std::vector<int64>& values) {
7267 // TODO(user): Check and maybe inline this code.
7268 const int size = values.size();
7269 DCHECK_GE(size, 0);
7270 switch (size) {
7271 case 0: {
7272 return;
7273 }
7274 case 1: {
7275 RemoveValue(values[0]);
7276 return;
7277 }
7278 case 2: {
7279 RemoveValue(values[0]);
7280 RemoveValue(values[1]);
7281 return;
7282 }
7283 case 3: {
7284 RemoveValue(values[0]);
7285 RemoveValue(values[1]);
7286 RemoveValue(values[2]);
7287 return;
7288 }
7289 default: {
7290 // 4 values, let's start doing some more clever things.
7291 // TODO(user) : Sort values!
7292 int start_index = 0;
7293 int64 new_min = Min();
7294 if (values[start_index] <= new_min) {
7295 while (start_index < size - 1 &&
7296 values[start_index + 1] == values[start_index] + 1) {
7297 new_min = values[start_index + 1] + 1;
7298 start_index++;
7299 }
7300 }
7301 int end_index = size - 1;
7302 int64 new_max = Max();
7303 if (values[end_index] >= new_max) {
7304 while (end_index > start_index + 1 &&
7305 values[end_index - 1] == values[end_index] - 1) {
7306 new_max = values[end_index - 1] - 1;
7307 end_index--;
7308 }
7309 }
7310 SetRange(new_min, new_max);
7311 for (int i = start_index; i <= end_index; ++i) {
7312 RemoveValue(values[i]);
7313 }
7314 }
7315 }
7316}
7317
7318void IntVar::Accept(ModelVisitor* const visitor) const {
7319 IntExpr* const casted = solver()->CastExpression(this);
7320 visitor->VisitIntegerVariable(this, casted);
7321}
7322
7323void IntVar::SetValues(const std::vector<int64>& values) {
7324 switch (values.size()) {
7325 case 0: {
7326 solver()->Fail();
7327 break;
7328 }
7329 case 1: {
7330 SetValue(values.back());
7331 break;
7332 }
7333 case 2: {
7334 if (Contains(values[0])) {
7335 if (Contains(values[1])) {
7336 const int64 l = std::min(values[0], values[1]);
7337 const int64 u = std::max(values[0], values[1]);
7338 SetRange(l, u);
7339 if (u > l + 1) {
7340 RemoveInterval(l + 1, u - 1);
7341 }
7342 } else {
7343 SetValue(values[0]);
7344 }
7345 } else {
7346 SetValue(values[1]);
7347 }
7348 break;
7349 }
7350 default: {
7351 // TODO(user): use a clean and safe SortedUniqueCopy() class
7352 // that uses a global, static shared (and locked) storage.
7353 // TODO(user): [optional] consider porting
7354 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7355 // existing open_source/base/stl_util.h and using it here.
7356 // TODO(user): We could filter out values not in the var.
7357 std::vector<int64>& tmp = solver()->tmp_vector_;
7358 tmp.clear();
7359 tmp.insert(tmp.end(), values.begin(), values.end());
7360 std::sort(tmp.begin(), tmp.end());
7361 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7362 const int size = tmp.size();
7363 const int64 vmin = Min();
7364 const int64 vmax = Max();
7365 int first = 0;
7366 int last = size - 1;
7367 if (tmp.front() > vmax || tmp.back() < vmin) {
7368 solver()->Fail();
7369 }
7370 // TODO(user) : We could find the first position >= vmin by dichotomy.
7371 while (tmp[first] < vmin || !Contains(tmp[first])) {
7372 ++first;
7373 if (first > last || tmp[first] > vmax) {
7374 solver()->Fail();
7375 }
7376 }
7377 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7378 // Note that last >= first implies tmp[last] >= vmin.
7379 --last;
7380 }
7381 DCHECK_GE(last, first);
7382 SetRange(tmp[first], tmp[last]);
7383 while (first < last) {
7384 const int64 start = tmp[first] + 1;
7385 const int64 end = tmp[first + 1] - 1;
7386 if (start <= end) {
7387 RemoveInterval(start, end);
7388 }
7389 first++;
7390 }
7391 }
7392 }
7393}
7394// ---------- BaseIntExpr ---------
7395
7396void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7397 if (!var->Bound()) {
7398 if (var->VarType() == DOMAIN_INT_VAR) {
7399 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7401 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7402 } else {
7403 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7404 expr);
7405 }
7406 }
7407}
7408
7409IntVar* BaseIntExpr::Var() {
7410 if (var_ == nullptr) {
7411 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7412 var_ = CastToVar();
7413 }
7414 return var_;
7415}
7416
7417IntVar* BaseIntExpr::CastToVar() {
7418 int64 vmin, vmax;
7419 Range(&vmin, &vmax);
7420 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7421 LinkVarExpr(solver(), this, var);
7422 return var;
7423}
7424
7425// Discovery methods
7426bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7427 IntExpr** const right) {
7428 if (expr->IsVar()) {
7429 IntVar* const expr_var = expr->Var();
7430 expr = CastExpression(expr_var);
7431 }
7432 // This is a dynamic cast to check the type of expr.
7433 // It returns nullptr is expr is not a subclass of SubIntExpr.
7434 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7435 if (sub_expr != nullptr) {
7436 *left = sub_expr->left();
7437 *right = sub_expr->right();
7438 return true;
7439 }
7440 return false;
7441}
7442
7443bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7444 bool* is_negated) const {
7445 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7446 *inner_var = expr->Var();
7447 *is_negated = false;
7448 return true;
7449 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7450 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7451 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7452 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7453 *is_negated = true;
7454 *inner_var = sub_var->SubVar();
7455 return true;
7456 }
7457 }
7458 return false;
7459}
7460
7461bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7462 int64* coefficient) {
7463 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7464 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7465 *coefficient = var->Constant();
7466 *inner_expr = var->SubVar();
7467 return true;
7468 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7469 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7470 *coefficient = prod->Constant();
7471 *inner_expr = prod->Expr();
7472 return true;
7473 }
7474 *inner_expr = expr;
7475 *coefficient = 1;
7476 return false;
7477}
7478
7479#undef COND_REV_ALLOC
7480
7481} // namespace operations_research
int64 min
Definition: alldiff_cst.cc:138
int64 max
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:495
#define DCHECK_LE(val1, val2)
Definition: base/logging.h:887
#define DCHECK_NE(val1, val2)
Definition: base/logging.h:886
#define CHECK_LT(val1, val2)
Definition: base/logging.h:700
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:697
#define CHECK_GE(val1, val2)
Definition: base/logging.h:701
#define CHECK_GT(val1, val2)
Definition: base/logging.h:702
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:889
#define CHECK_NE(val1, val2)
Definition: base/logging.h:698
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:890
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:888
#define LOG(severity)
Definition: base/logging.h:420
#define DCHECK(condition)
Definition: base/logging.h:884
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:885
This is the base class for all expressions that are not variables.
A BaseObject is the root of all reversibly allocated objects.
IntVar * IsDifferent(int64 constant) override
Definition: expressions.cc:143
void RemoveInterval(int64 l, int64 u) override
This method removes the interval 'l' .. 'u' from the domain of the variable.
Definition: expressions.cc:103
void SetMax(int64 m) override
Definition: expressions.cc:74
IntVar * IsLessOrEqual(int64 constant) override
Definition: expressions.cc:164
void WhenBound(Demon *d) override
This method attaches a demon that will be awakened when the variable is bound.
Definition: expressions.cc:114
bool Contains(int64 v) const override
This method returns whether the value 'v' is in the domain of the variable.
Definition: expressions.cc:128
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetRange(int64 mi, int64 ma) override
This method sets both the min and the max of the expression.
Definition: expressions.cc:80
void RemoveValue(int64 v) override
This method removes the value 'v' from the domain of the variable.
Definition: expressions.cc:91
IntVar * IsEqual(int64 constant) override
IsEqual.
Definition: expressions.cc:132
IntVar * IsGreaterOrEqual(int64 constant) override
Definition: expressions.cc:154
SimpleRevFIFO< Demon * > bound_demons_
void SetMin(int64 m) override
Definition: expressions.cc:68
std::string DebugString() const override
Definition: expressions.cc:174
uint64 Size() const override
This method returns the number of values in the domain of the variable.
Definition: expressions.cc:124
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
virtual Solver::DemonPriority priority() const
This method returns the priority of the demon.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual void SetValue(int64 v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual int64 Max() const =0
virtual IntVar * Var()=0
Creates a variable from the expression.
IntVar * VarWithName(const std::string &name)
Creates a variable from the expression and set the name of the resulting var.
Definition: expressions.cc:49
virtual int64 Min() const =0
The class IntVar is a subset of IntExpr.
IntVar(Solver *const s)
Definition: expressions.cc:57
IntVar * Var() override
Creates a variable from the expression.
virtual int VarType() const
The class Iterator has two direct subclasses.
virtual void VisitIntegerVariable(const IntVar *const variable, IntExpr *const delegate)
void SetRange(int64 l, int64 u) override
This method sets both the min and the max of the expression.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
void SetValue(Solver *const s, const T &val)
DemonPriority
This enum represents the three possible priorities for a demon in the Solver queue.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ DELAYED_PRIORITY
DELAYED_PRIORITY is the lowest priority: Demons will be processed after VAR_PRIORITY and NORMAL_PRIORITY demons.
@ OUTSIDE_SEARCH
Before search, after search.
IntExpr * MakeDifference(IntExpr *const left, IntExpr *const right)
left - right
void AddCastConstraint(CastConstraint *const constraint, IntVar *const target_var, IntExpr *const expr)
Adds 'constraint' to the solver and marks it as a cast constraint, that is, a constraint created by calling Var() on an expression. This is used internally.
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
T * RevAlloc(T *object)
Registers the given object as being reversible.
IntVar * MakeIntConst(int64 val, const std::string &name)
IntConst will create a constant expression.
std::vector< IntVarIterator * > holes_
const std::string name
const Constraint * ct
int64 value
IntVar *const expr_
Definition: element.cc:85
IntVar * var
Definition: expr_array.cc:1858
const int64 limit_
#define COND_REV_ALLOC(rev, alloc)
Solver *const solver_
Definition: expressions.cc:274
const int64 pow_
const int64 cst_
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
IntVarIterator *const iterator_
static const int64 kint64max
int64_t int64
static const int32 kint32max
uint64_t uint64
static const int64 kint64min
const int64 offset_
Definition: interval.cc:2076
Handler handler_
Definition: interval.cc:420
bool in_process_
Definition: interval.cc:419
const int FATAL
Definition: log_severity.h:32
#define DISALLOW_COPY_AND_ASSIGN(TypeName)
Definition: macros.h:29
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
Definition: protobuf_util.h:40
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition: map_util.h:70
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition: stl_util.h:58
std::function< int64(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1487
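The vehicle routing library lets one model and solve generic vehicle routing problems ranging from the Traveling Salesman Problem to more complex problems such as the Capacitated Vehicle Routing Problem with Time Windows.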
int64 PosIntDivUp(int64 e, int64 v)
int64 CapAdd(int64 x, int64 y)
int64 UnsafeLeastSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
int64 CapProd(int64 x, int64 y)
int64 PosIntDivDown(int64 e, int64 v)
void CleanVariableOnFail(IntVar *const var)
uint64 BitLength64(uint64 size)
Definition: bitset.h:338
int64 CapSub(int64 x, int64 y)
bool IsBitSet64(const uint64 *const bitset, uint64 pos)
Definition: bitset.h:346
int MostSignificantBitPosition64(uint64 n)
Definition: bitset.h:231
uint32 BitPos64(uint64 pos)
Definition: bitset.h:330
bool AddOverflows(int64 x, int64 y)
Constraint * SetIsGreaterOrEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
uint64 OneRange64(uint64 s, uint64 e)
Definition: bitset.h:285
static const uint64 kAllBits64
Definition: bitset.h:33
uint64 OneBit64(int pos)
Definition: bitset.h:38
void RestoreBoolValue(IntVar *const var)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
Constraint * SetIsEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
uint64 BitOffset64(uint64 pos)
Definition: bitset.h:334
std::vector< int64 > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:822
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
int LeastSignificantBitPosition64(uint64 n)
Definition: bitset.h:127
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
uint64 BitCountRange64(const uint64 *const bitset, uint64 start, uint64 end)
int64 SubOverflows(int64 x, int64 y)
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
uint64 BitCount64(uint64 n)
Definition: bitset.h:42
int64 UnsafeMostSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
int index
Definition: pack.cc:508
int64 coefficient
IntervalVar *const target_var_
int64 step_
Definition: search.cc:2952
const int64 stamp_
Definition: search.cc:3039
int64 current_
Definition: search.cc:2953