OR-Tools  8.2
element.cc
Go to the documentation of this file.
1// Copyright 2010-2018 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <memory>
16#include <numeric>
17#include <string>
18#include <utility>
19#include <vector>
20
21#include "absl/strings/str_format.h"
22#include "absl/strings/str_join.h"
29
30ABSL_FLAG(bool, cp_disable_element_cache, true,
31 "If true, caching for IntElement is disabled.");
32
33namespace operations_research {
34
35// ----- IntExprElement -----
36void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var);
37
38namespace {
39
// Comparison functor over positions: `x` orders before `y` when the entry
// stored at position `x` of the referenced vector is smaller than the entry
// at position `y`. The vector is borrowed (not copied) and must outlive the
// functor.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : table_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *table_;
    const T& left = table[x];
    const T& right = table[y];
    return left < right;
  }

 private:
  const std::vector<T>* table_;
};
51
// Comparison functor over positions: `x` orders before `y` when the entry
// stored at position `x` of the referenced vector is larger than the entry
// at position `y`. The vector is borrowed (not copied) and must outlive the
// functor.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : table_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *table_;
    const T& left = table[x];
    const T& right = table[y];
    return left > right;
  }

 private:
  const std::vector<T>* table_;
};
63
64// ----- BaseIntExprElement -----
65
66class BaseIntExprElement : public BaseIntExpr {
67 public:
68 BaseIntExprElement(Solver* const s, IntVar* const e);
69 ~BaseIntExprElement() override {}
70 int64 Min() const override;
71 int64 Max() const override;
72 void Range(int64* mi, int64* ma) override;
73 void SetMin(int64 m) override;
74 void SetMax(int64 m) override;
75 void SetRange(int64 mi, int64 ma) override;
76 bool Bound() const override { return (expr_->Bound()); }
77 // TODO(user) : improve me, the previous test is not always true
78 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
79
80 protected:
81 virtual int64 ElementValue(int index) const = 0;
82 virtual int64 ExprMin() const = 0;
83 virtual int64 ExprMax() const = 0;
84
85 IntVar* const expr_;
86
87 private:
88 void UpdateSupports() const;
89
90 mutable int64 min_;
91 mutable int min_support_;
92 mutable int64 max_;
93 mutable int max_support_;
94 mutable bool initial_update_;
95 IntVarIterator* const expr_iterator_;
96};
97
98BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
99 : BaseIntExpr(s),
100 expr_(e),
101 min_(0),
102 min_support_(-1),
103 max_(0),
104 max_support_(-1),
105 initial_update_(true),
106 expr_iterator_(expr_->MakeDomainIterator(true)) {
107 CHECK(s != nullptr);
108 CHECK(e != nullptr);
109}
110
111int64 BaseIntExprElement::Min() const {
112 UpdateSupports();
113 return min_;
114}
115
116int64 BaseIntExprElement::Max() const {
117 UpdateSupports();
118 return max_;
119}
120
121void BaseIntExprElement::Range(int64* mi, int64* ma) {
122 UpdateSupports();
123 *mi = min_;
124 *ma = max_;
125}
126
127#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test) \
128 const int64 emin = ExprMin(); \
129 const int64 emax = ExprMax(); \
130 int64 nmin = emin; \
131 int64 value = ElementValue(nmin); \
132 while (nmin < emax && test) { \
133 nmin++; \
134 value = ElementValue(nmin); \
135 } \
136 if (nmin == emax && test) { \
137 solver()->Fail(); \
138 } \
139 int64 nmax = emax; \
140 value = ElementValue(nmax); \
141 while (nmax >= nmin && test) { \
142 nmax--; \
143 value = ElementValue(nmax); \
144 } \
145 expr_->SetRange(nmin, nmax);
146
147void BaseIntExprElement::SetMin(int64 m) {
149}
150
151void BaseIntExprElement::SetMax(int64 m) {
153}
154
155void BaseIntExprElement::SetRange(int64 mi, int64 ma) {
156 if (mi > ma) {
157 solver()->Fail();
158 }
159 UPDATE_BASE_ELEMENT_INDEX_BOUNDS((value < mi || value > ma));
160}
161
162#undef UPDATE_BASE_ELEMENT_INDEX_BOUNDS
163
164void BaseIntExprElement::UpdateSupports() const {
165 if (initial_update_ || !expr_->Contains(min_support_) ||
166 !expr_->Contains(max_support_)) {
167 const int64 emin = ExprMin();
168 const int64 emax = ExprMax();
169 int64 min_value = ElementValue(emax);
170 int64 max_value = min_value;
171 int min_support = emax;
172 int max_support = emax;
173 const uint64 expr_size = expr_->Size();
174 if (expr_size > 1) {
175 if (expr_size == emax - emin + 1) {
176 // Value(emax) already stored in min_value, max_value.
177 for (int64 index = emin; index < emax; ++index) {
178 const int64 value = ElementValue(index);
179 if (value > max_value) {
180 max_value = value;
181 max_support = index;
182 } else if (value < min_value) {
183 min_value = value;
184 min_support = index;
185 }
186 }
187 } else {
188 for (const int64 index : InitAndGetValues(expr_iterator_)) {
189 if (index >= emin && index <= emax) {
190 const int64 value = ElementValue(index);
191 if (value > max_value) {
192 max_value = value;
193 max_support = index;
194 } else if (value < min_value) {
195 min_value = value;
196 min_support = index;
197 }
198 }
199 }
200 }
201 }
202 Solver* s = solver();
203 s->SaveAndSetValue(&min_, min_value);
204 s->SaveAndSetValue(&min_support_, min_support);
205 s->SaveAndSetValue(&max_, max_value);
206 s->SaveAndSetValue(&max_support_, max_support);
207 s->SaveAndSetValue(&initial_update_, false);
208 }
209}
210
211// ----- IntElementConstraint -----
212
213// This constraint implements 'elem' == 'values'['index'].
214// It scans the bounds of 'elem' to propagate on the domain of 'index'.
215// It scans the domain of 'index' to compute the new bounds of 'elem'.
216class IntElementConstraint : public CastConstraint {
217 public:
218 IntElementConstraint(Solver* const s, const std::vector<int64>& values,
219 IntVar* const index, IntVar* const elem)
220 : CastConstraint(s, elem),
221 values_(values),
222 index_(index),
223 index_iterator_(index_->MakeDomainIterator(true)) {
224 CHECK(index != nullptr);
225 }
226
227 void Post() override {
228 Demon* const d =
229 solver()->MakeDelayedConstraintInitialPropagateCallback(this);
230 index_->WhenDomain(d);
231 target_var_->WhenRange(d);
232 }
233
234 void InitialPropagate() override {
235 index_->SetRange(0, values_.size() - 1);
236 const int64 target_var_min = target_var_->Min();
237 const int64 target_var_max = target_var_->Max();
238 int64 new_min = target_var_max;
239 int64 new_max = target_var_min;
240 to_remove_.clear();
241 for (const int64 index : InitAndGetValues(index_iterator_)) {
242 const int64 value = values_[index];
243 if (value < target_var_min || value > target_var_max) {
244 to_remove_.push_back(index);
245 } else {
246 if (value < new_min) {
247 new_min = value;
248 }
249 if (value > new_max) {
250 new_max = value;
251 }
252 }
253 }
254 target_var_->SetRange(new_min, new_max);
255 if (!to_remove_.empty()) {
256 index_->RemoveValues(to_remove_);
257 }
258 }
259
260 std::string DebugString() const override {
261 return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
262 absl::StrJoin(values_, ", "), index_->DebugString(),
263 target_var_->DebugString());
264 }
265
266 void Accept(ModelVisitor* const visitor) const override {
267 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
268 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
269 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
270 index_);
271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
273 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
274 }
275
276 private:
277 const std::vector<int64> values_;
278 IntVar* const index_;
279 IntVarIterator* const index_iterator_;
280 std::vector<int64> to_remove_;
281};
282
283// ----- IntExprElement
284
285IntVar* BuildDomainIntVar(Solver* const solver, std::vector<int64>* values);
286
287class IntExprElement : public BaseIntExprElement {
288 public:
289 IntExprElement(Solver* const s, const std::vector<int64>& vals,
290 IntVar* const expr)
291 : BaseIntExprElement(s, expr), values_(vals) {}
292
293 ~IntExprElement() override {}
294
295 std::string name() const override {
296 const int size = values_.size();
297 if (size > 10) {
298 return absl::StrFormat("IntElement(array of size %d, %s)", size,
299 expr_->name());
300 } else {
301 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
302 expr_->name());
303 }
304 }
305
306 std::string DebugString() const override {
307 const int size = values_.size();
308 if (size > 10) {
309 return absl::StrFormat("IntElement(array of size %d, %s)", size,
310 expr_->DebugString());
311 } else {
312 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
313 expr_->DebugString());
314 }
315 }
316
317 IntVar* CastToVar() override {
318 Solver* const s = solver();
319 IntVar* const var = s->MakeIntVar(values_);
320 s->AddCastConstraint(
321 s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
322 this);
323 return var;
324 }
325
326 void Accept(ModelVisitor* const visitor) const override {
327 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
328 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
329 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
330 expr_);
331 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
332 }
333
334 protected:
335 int64 ElementValue(int index) const override {
336 DCHECK_LT(index, values_.size());
337 return values_[index];
338 }
339 int64 ExprMin() const override { return std::max<int64>(0, expr_->Min()); }
340 int64 ExprMax() const override {
341 return values_.empty() ? 0
342 : std::min<int64>(values_.size() - 1, expr_->Max());
343 }
344
345 private:
346 const std::vector<int64> values_;
347};
348
349// ----- Range Minimum Query-based Element -----
350
351class RangeMinimumQueryExprElement : public BaseIntExpr {
352 public:
353 RangeMinimumQueryExprElement(Solver* solver, const std::vector<int64>& values,
354 IntVar* index);
355 ~RangeMinimumQueryExprElement() override {}
356 int64 Min() const override;
357 int64 Max() const override;
358 void Range(int64* mi, int64* ma) override;
359 void SetMin(int64 m) override;
360 void SetMax(int64 m) override;
361 void SetRange(int64 mi, int64 ma) override;
362 bool Bound() const override { return (index_->Bound()); }
363 // TODO(user) : improve me, the previous test is not always true
364 void WhenRange(Demon* d) override { index_->WhenRange(d); }
365 IntVar* CastToVar() override {
366 // TODO(user): Should we try to make holes in the domain of index_, as we
367 // do here, or should we only propagate bounds as we do in
368 // IncreasingIntExprElement ?
369 IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
370 solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
371 solver(), min_rmq_.array(), index_, var)),
372 var, this);
373 return var;
374 }
375 void Accept(ModelVisitor* const visitor) const override {
376 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
377 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
378 min_rmq_.array());
379 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
380 index_);
381 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
382 }
383
384 private:
385 int64 IndexMin() const { return std::max<int64>(0, index_->Min()); }
386 int64 IndexMax() const {
387 return std::min<int64>(min_rmq_.array().size() - 1, index_->Max());
388 }
389
390 IntVar* const index_;
391 const RangeMinimumQuery<int64, std::less<int64>> min_rmq_;
392 const RangeMinimumQuery<int64, std::greater<int64>> max_rmq_;
393};
394
395RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
396 Solver* solver, const std::vector<int64>& values, IntVar* index)
397 : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
398 CHECK(solver != nullptr);
399 CHECK(index != nullptr);
400}
401
402int64 RangeMinimumQueryExprElement::Min() const {
403 return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
404}
405
406int64 RangeMinimumQueryExprElement::Max() const {
407 return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
408}
409
410void RangeMinimumQueryExprElement::Range(int64* mi, int64* ma) {
411 const int64 range_min = IndexMin();
412 const int64 range_max = IndexMax() + 1;
413 *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
414 *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
415}
416
417#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \
418 const std::vector<int64>& values = min_rmq_.array(); \
419 int64 index_min = IndexMin(); \
420 int64 index_max = IndexMax(); \
421 int64 value = values[index_min]; \
422 while (index_min < index_max && (test)) { \
423 index_min++; \
424 value = values[index_min]; \
425 } \
426 if (index_min == index_max && (test)) { \
427 solver()->Fail(); \
428 } \
429 value = values[index_max]; \
430 while (index_max >= index_min && (test)) { \
431 index_max--; \
432 value = values[index_max]; \
433 } \
434 index_->SetRange(index_min, index_max);
435
436void RangeMinimumQueryExprElement::SetMin(int64 m) {
438}
439
440void RangeMinimumQueryExprElement::SetMax(int64 m) {
442}
443
444void RangeMinimumQueryExprElement::SetRange(int64 mi, int64 ma) {
445 if (mi > ma) {
446 solver()->Fail();
447 }
448 UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < mi || value > ma);
449}
450
451#undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
452
453// ----- Increasing Element -----
454
455class IncreasingIntExprElement : public BaseIntExpr {
456 public:
457 IncreasingIntExprElement(Solver* const s, const std::vector<int64>& values,
458 IntVar* const index);
459 ~IncreasingIntExprElement() override {}
460
461 int64 Min() const override;
462 void SetMin(int64 m) override;
463 int64 Max() const override;
464 void SetMax(int64 m) override;
465 void SetRange(int64 mi, int64 ma) override;
466 bool Bound() const override { return (index_->Bound()); }
467 // TODO(user) : improve me, the previous test is not always true
468 std::string name() const override {
469 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
470 index_->name());
471 }
472 std::string DebugString() const override {
473 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
474 index_->DebugString());
475 }
476
477 void Accept(ModelVisitor* const visitor) const override {
478 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
479 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
480 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
481 index_);
482 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
483 }
484
485 void WhenRange(Demon* d) override { index_->WhenRange(d); }
486
487 IntVar* CastToVar() override {
488 Solver* const s = solver();
489 IntVar* const var = s->MakeIntVar(values_);
490 LinkVarExpr(s, this, var);
491 return var;
492 }
493
494 private:
495 const std::vector<int64> values_;
496 IntVar* const index_;
497};
498
499IncreasingIntExprElement::IncreasingIntExprElement(
500 Solver* const s, const std::vector<int64>& values, IntVar* const index)
501 : BaseIntExpr(s), values_(values), index_(index) {
502 DCHECK(index);
503 DCHECK(s);
504}
505
506int64 IncreasingIntExprElement::Min() const {
507 const int64 expression_min = std::max<int64>(0, index_->Min());
508 return (expression_min < values_.size() ? values_[expression_min]
509 : kint64max);
510}
511
512void IncreasingIntExprElement::SetMin(int64 m) {
513 const int64 index_min = std::max<int64>(0, index_->Min());
514 const int64 index_max = std::min<int64>(values_.size() - 1, index_->Max());
515
516 if (index_min > index_max || m > values_[index_max]) {
517 solver()->Fail();
518 }
519
520 const std::vector<int64>::const_iterator first =
521 std::lower_bound(values_.begin(), values_.end(), m);
522 const int64 new_index_min = first - values_.begin();
523 index_->SetMin(new_index_min);
524}
525
526int64 IncreasingIntExprElement::Max() const {
527 const int64 expression_max =
528 std::min<int64>(values_.size() - 1, index_->Max());
529 return (expression_max >= 0 ? values_[expression_max] : kint64max);
530}
531
532void IncreasingIntExprElement::SetMax(int64 m) {
533 int64 index_min = std::max<int64>(0, index_->Min());
534 if (m < values_[index_min]) {
535 solver()->Fail();
536 }
537
538 const std::vector<int64>::const_iterator last_after =
539 std::upper_bound(values_.begin(), values_.end(), m);
540 const int64 new_index_max = (last_after - values_.begin()) - 1;
541 index_->SetRange(0, new_index_max);
542}
543
544void IncreasingIntExprElement::SetRange(int64 mi, int64 ma) {
545 if (mi > ma) {
546 solver()->Fail();
547 }
548 const int64 index_min = std::max<int64>(0, index_->Min());
549 const int64 index_max = std::min<int64>(values_.size() - 1, index_->Max());
550
551 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
552 solver()->Fail();
553 }
554
555 const std::vector<int64>::const_iterator first =
556 std::lower_bound(values_.begin(), values_.end(), mi);
557 const int64 new_index_min = first - values_.begin();
558
559 const std::vector<int64>::const_iterator last_after =
560 std::upper_bound(first, values_.end(), ma);
561 const int64 new_index_max = (last_after - values_.begin()) - 1;
562
563 // Assign.
564 index_->SetRange(new_index_min, new_index_max);
565}
566
567// ----- Solver::MakeElement(int array, int var) -----
568IntExpr* BuildElement(Solver* const solver, const std::vector<int64>& values,
569 IntVar* const index) {
570 // Various checks.
571 // Is array constant?
572 if (IsArrayConstant(values, values[0])) {
573 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
574 return solver->MakeIntConst(values[0]);
575 }
576 // Is array built with booleans only?
577 // TODO(user): We could maintain the index of the first one.
578 if (IsArrayBoolean(values)) {
579 std::vector<int64> ones;
580 int first_zero = -1;
581 for (int i = 0; i < values.size(); ++i) {
582 if (values[i] == 1) {
583 ones.push_back(i);
584 } else {
585 first_zero = i;
586 }
587 }
588 if (ones.size() == 1) {
589 DCHECK_EQ(int64{1}, values[ones.back()]);
590 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
591 return solver->MakeIsEqualCstVar(index, ones.back());
592 } else if (ones.size() == values.size() - 1) {
593 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
594 return solver->MakeIsDifferentCstVar(index, first_zero);
595 } else if (ones.size() == ones.back() - ones.front() + 1) { // contiguous.
596 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
597 IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
598 solver->AddConstraint(
599 solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
600 return b;
601 } else {
602 IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
603 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
604 solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
605 return b;
606 }
607 }
608 IntExpr* cache = nullptr;
609 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
610 cache = solver->Cache()->FindVarConstantArrayExpression(
611 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
612 }
613 if (cache != nullptr) {
614 return cache;
615 } else {
616 IntExpr* result = nullptr;
617 if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
618 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
619 values[0]);
620 } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
621 solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
622 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
623 values[0]);
624 } else if (IsIncreasingContiguous(values)) {
625 result = solver->MakeSum(index, values[0]);
626 } else if (IsIncreasing(values)) {
627 result = solver->RegisterIntExpr(solver->RevAlloc(
628 new IncreasingIntExprElement(solver, values, index)));
629 } else {
630 if (solver->parameters().use_element_rmq()) {
631 result = solver->RegisterIntExpr(solver->RevAlloc(
632 new RangeMinimumQueryExprElement(solver, values, index)));
633 } else {
634 result = solver->RegisterIntExpr(
635 solver->RevAlloc(new IntExprElement(solver, values, index)));
636 }
637 }
638 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
639 solver->Cache()->InsertVarConstantArrayExpression(
640 result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
641 }
642 return result;
643 }
644}
645} // namespace
646
647IntExpr* Solver::MakeElement(const std::vector<int64>& values,
648 IntVar* const index) {
649 DCHECK(index);
650 DCHECK_EQ(this, index->solver());
651 if (index->Bound()) {
652 return MakeIntConst(values[index->Min()]);
653 }
654 return BuildElement(this, values, index);
655}
656
657IntExpr* Solver::MakeElement(const std::vector<int>& values,
658 IntVar* const index) {
659 DCHECK(index);
660 DCHECK_EQ(this, index->solver());
661 if (index->Bound()) {
662 return MakeIntConst(values[index->Min()]);
663 }
664 return BuildElement(this, ToInt64Vector(values), index);
665}
666
667// ----- IntExprFunctionElement -----
668
669namespace {
670class IntExprFunctionElement : public BaseIntExprElement {
671 public:
672 IntExprFunctionElement(Solver* const s, Solver::IndexEvaluator1 values,
673 IntVar* const e);
674 ~IntExprFunctionElement() override;
675
676 std::string name() const override {
677 return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
678 }
679
680 std::string DebugString() const override {
681 return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
682 }
683
684 void Accept(ModelVisitor* const visitor) const override {
685 // Warning: This will expand all values into a vector.
686 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
687 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
688 expr_);
689 visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
690 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
691 }
692
693 protected:
694 int64 ElementValue(int index) const override { return values_(index); }
695 int64 ExprMin() const override { return expr_->Min(); }
696 int64 ExprMax() const override { return expr_->Max(); }
697
698 private:
699 Solver::IndexEvaluator1 values_;
700};
701
702IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
703 Solver::IndexEvaluator1 values,
704 IntVar* const e)
705 : BaseIntExprElement(s, e), values_(std::move(values)) {
706 CHECK(values_ != nullptr);
707}
708
709IntExprFunctionElement::~IntExprFunctionElement() {}
710
711// ----- Increasing Element -----
712
713class IncreasingIntExprFunctionElement : public BaseIntExpr {
714 public:
715 IncreasingIntExprFunctionElement(Solver* const s,
717 IntVar* const index)
718 : BaseIntExpr(s), values_(std::move(values)), index_(index) {
719 DCHECK(values_ != nullptr);
720 DCHECK(index);
721 DCHECK(s);
722 }
723
724 ~IncreasingIntExprFunctionElement() override {}
725
726 int64 Min() const override { return values_(index_->Min()); }
727
728 void SetMin(int64 m) override {
729 const int64 index_min = index_->Min();
730 const int64 index_max = index_->Max();
731 if (m > values_(index_max)) {
732 solver()->Fail();
733 }
734 const int64 new_index_min = FindNewIndexMin(index_min, index_max, m);
735 index_->SetMin(new_index_min);
736 }
737
738 int64 Max() const override { return values_(index_->Max()); }
739
740 void SetMax(int64 m) override {
741 int64 index_min = index_->Min();
742 int64 index_max = index_->Max();
743 if (m < values_(index_min)) {
744 solver()->Fail();
745 }
746 const int64 new_index_max = FindNewIndexMax(index_min, index_max, m);
747 index_->SetMax(new_index_max);
748 }
749
750 void SetRange(int64 mi, int64 ma) override {
751 const int64 index_min = index_->Min();
752 const int64 index_max = index_->Max();
753 const int64 value_min = values_(index_min);
754 const int64 value_max = values_(index_max);
755 if (mi > ma || ma < value_min || mi > value_max) {
756 solver()->Fail();
757 }
758 if (mi <= value_min && ma >= value_max) {
759 // Nothing to do.
760 return;
761 }
762
763 const int64 new_index_min = FindNewIndexMin(index_min, index_max, mi);
764 const int64 new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
765 // Assign.
766 index_->SetRange(new_index_min, new_index_max);
767 }
768
769 std::string name() const override {
770 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
771 index_->name());
772 }
773
774 std::string DebugString() const override {
775 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
776 index_->DebugString());
777 }
778
779 void WhenRange(Demon* d) override { index_->WhenRange(d); }
780
781 void Accept(ModelVisitor* const visitor) const override {
782 // Warning: This will expand all values into a vector.
783 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
784 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
785 index_);
786 if (index_->Min() == 0) {
787 visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
788 index_->Max());
789 } else {
790 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
791 index_->Max());
792 }
793 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
794 }
795
796 private:
797 int64 FindNewIndexMin(int64 index_min, int64 index_max, int64 m) {
798 if (m <= values_(index_min)) {
799 return index_min;
800 }
801
802 DCHECK_LT(values_(index_min), m);
803 DCHECK_GE(values_(index_max), m);
804
805 int64 index_lower_bound = index_min;
806 int64 index_upper_bound = index_max;
807 while (index_upper_bound - index_lower_bound > 1) {
808 DCHECK_LT(values_(index_lower_bound), m);
809 DCHECK_GE(values_(index_upper_bound), m);
810 const int64 pivot = (index_lower_bound + index_upper_bound) / 2;
811 const int64 pivot_value = values_(pivot);
812 if (pivot_value < m) {
813 index_lower_bound = pivot;
814 } else {
815 index_upper_bound = pivot;
816 }
817 }
818 DCHECK(values_(index_upper_bound) >= m);
819 return index_upper_bound;
820 }
821
822 int64 FindNewIndexMax(int64 index_min, int64 index_max, int64 m) {
823 if (m >= values_(index_max)) {
824 return index_max;
825 }
826
827 DCHECK_LE(values_(index_min), m);
828 DCHECK_GT(values_(index_max), m);
829
830 int64 index_lower_bound = index_min;
831 int64 index_upper_bound = index_max;
832 while (index_upper_bound - index_lower_bound > 1) {
833 DCHECK_LE(values_(index_lower_bound), m);
834 DCHECK_GT(values_(index_upper_bound), m);
835 const int64 pivot = (index_lower_bound + index_upper_bound) / 2;
836 const int64 pivot_value = values_(pivot);
837 if (pivot_value > m) {
838 index_upper_bound = pivot;
839 } else {
840 index_lower_bound = pivot;
841 }
842 }
843 DCHECK(values_(index_lower_bound) <= m);
844 return index_lower_bound;
845 }
846
848 IntVar* const index_;
849};
850} // namespace
851
853 IntVar* const index) {
854 CHECK_EQ(this, index->solver());
855 return RegisterIntExpr(
856 RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
857}
858
860 bool increasing, IntVar* const index) {
861 CHECK_EQ(this, index->solver());
862 if (increasing) {
863 return RegisterIntExpr(
864 RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
865 } else {
866 // You need to pass by copy such that opposite_value does not include a
867 // dandling reference when leaving this scope.
868 Solver::IndexEvaluator1 opposite_values = [values](int64 i) {
869 return -values(i);
870 };
872 new IncreasingIntExprFunctionElement(this, opposite_values, index))));
873 }
874}
875
876// ----- IntIntExprFunctionElement -----
877
878namespace {
879class IntIntExprFunctionElement : public BaseIntExpr {
880 public:
881 IntIntExprFunctionElement(Solver* const s, Solver::IndexEvaluator2 values,
882 IntVar* const expr1, IntVar* const expr2);
883 ~IntIntExprFunctionElement() override;
884 std::string DebugString() const override {
885 return absl::StrFormat("IntIntFunctionElement(%s,%s)",
886 expr1_->DebugString(), expr2_->DebugString());
887 }
888 int64 Min() const override;
889 int64 Max() const override;
890 void Range(int64* lower_bound, int64* upper_bound) override;
891 void SetMin(int64 lower_bound) override;
892 void SetMax(int64 upper_bound) override;
893 void SetRange(int64 lower_bound, int64 upper_bound) override;
894 bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
895 // TODO(user) : improve me, the previous test is not always true
896 void WhenRange(Demon* d) override {
897 expr1_->WhenRange(d);
898 expr2_->WhenRange(d);
899 }
900
901 void Accept(ModelVisitor* const visitor) const override {
902 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
903 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
904 expr1_);
905 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
906 expr2_);
907 // Warning: This will expand all values into a vector.
908 const int64 expr1_min = expr1_->Min();
909 const int64 expr1_max = expr1_->Max();
910 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
911 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
912 for (int i = expr1_min; i <= expr1_max; ++i) {
913 visitor->VisitInt64ToInt64Extension(
914 [this, i](int64 j) { return values_(i, j); }, expr2_->Min(),
915 expr2_->Max());
916 }
917 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
918 }
919
920 private:
921 int64 ElementValue(int index1, int index2) const {
922 return values_(index1, index2);
923 }
924 void UpdateSupports() const;
925
926 IntVar* const expr1_;
927 IntVar* const expr2_;
928 mutable int64 min_;
929 mutable int min_support1_;
930 mutable int min_support2_;
931 mutable int64 max_;
932 mutable int max_support1_;
933 mutable int max_support2_;
934 mutable bool initial_update_;
936 IntVarIterator* const expr1_iterator_;
937 IntVarIterator* const expr2_iterator_;
938};
939
940IntIntExprFunctionElement::IntIntExprFunctionElement(
941 Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
942 IntVar* const expr2)
943 : BaseIntExpr(s),
944 expr1_(expr1),
945 expr2_(expr2),
946 min_(0),
947 min_support1_(-1),
948 min_support2_(-1),
949 max_(0),
950 max_support1_(-1),
951 max_support2_(-1),
952 initial_update_(true),
953 values_(std::move(values)),
954 expr1_iterator_(expr1_->MakeDomainIterator(true)),
955 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
956 CHECK(values_ != nullptr);
957}
958
959IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
960
961int64 IntIntExprFunctionElement::Min() const {
962 UpdateSupports();
963 return min_;
964}
965
966int64 IntIntExprFunctionElement::Max() const {
967 UpdateSupports();
968 return max_;
969}
970
971void IntIntExprFunctionElement::Range(int64* lower_bound, int64* upper_bound) {
972 UpdateSupports();
973 *lower_bound = min_;
974 *upper_bound = max_;
975}
976
// Shrinks expr1_ and expr2_ to the tightest ranges whose end points still take
// part in at least one index pair (index1, index2) of the current ranges whose
// element value satisfies `test`. `test` is a boolean expression over a local
// int64 named `value` (the element value of the pair being examined). Fails
// when no pair satisfies `test`.
//
// Loop counters are int64, matching the bounds they iterate over, so bounds
// outside the int range are not truncated before the scan. ElementValue()
// itself still takes int indices.
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)                             \
  const int64 emin1 = expr1_->Min();                                  \
  const int64 emax1 = expr1_->Max();                                  \
  const int64 emin2 = expr2_->Min();                                  \
  const int64 emax2 = expr2_->Max();                                  \
  /* True iff the element at (index1, index2) satisfies test. */      \
  const auto satisfies = [&](int64 index1, int64 index2) {            \
    const int64 value = ElementValue(index1, index2);                 \
    return static_cast<bool>(test);                                   \
  };                                                                  \
  /* True iff some index2 in [emin2, emax2] supports index1. */       \
  const auto row_supported = [&](int64 index1) {                      \
    for (int64 index2 = emin2; index2 <= emax2; ++index2) {           \
      if (satisfies(index1, index2)) return true;                     \
    }                                                                 \
    return false;                                                     \
  };                                                                  \
  /* True iff some index1 in [emin1, emax1] supports index2. */       \
  const auto column_supported = [&](int64 index2) {                   \
    for (int64 index1 = emin1; index1 <= emax1; ++index1) {           \
      if (satisfies(index1, index2)) return true;                     \
    }                                                                 \
    return false;                                                     \
  };                                                                  \
  int64 nmin1 = emin1;                                                \
  while (nmin1 <= emax1 && !row_supported(nmin1)) ++nmin1;            \
  if (nmin1 > emax1) {                                                \
    solver()->Fail();                                                 \
  }                                                                   \
  int64 nmin2 = emin2;                                                \
  while (nmin2 <= emax2 && !column_supported(nmin2)) ++nmin2;         \
  if (nmin2 > emax2) {                                                \
    solver()->Fail();                                                 \
  }                                                                   \
  int64 nmax1 = emax1;                                                \
  while (nmax1 >= nmin1 && !row_supported(nmax1)) --nmax1;            \
  int64 nmax2 = emax2;                                                \
  while (nmax2 >= nmin2 && !column_supported(nmax2)) --nmax2;         \
  expr1_->SetRange(nmin1, nmax1);                                     \
  expr2_->SetRange(nmin2, nmax2);
1046
1047void IntIntExprFunctionElement::SetMin(int64 lower_bound) {
1048 UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound);
1049}
1050
1051void IntIntExprFunctionElement::SetMax(int64 upper_bound) {
1052 UPDATE_ELEMENT_INDEX_BOUNDS(value <= upper_bound);
1053}
1054
1055void IntIntExprFunctionElement::SetRange(int64 lower_bound, int64 upper_bound) {
1056 if (lower_bound > upper_bound) {
1057 solver()->Fail();
1058 }
1059 UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound && value <= upper_bound);
1060}
1061
1062#undef UPDATE_ELEMENT_INDEX_BOUNDS
1063
1064void IntIntExprFunctionElement::UpdateSupports() const {
1065 if (initial_update_ || !expr1_->Contains(min_support1_) ||
1066 !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
1067 !expr2_->Contains(max_support2_)) {
1068 const int64 emax1 = expr1_->Max();
1069 const int64 emax2 = expr2_->Max();
1070 int64 min_value = ElementValue(emax1, emax2);
1071 int64 max_value = min_value;
1072 int min_support1 = emax1;
1073 int max_support1 = emax1;
1074 int min_support2 = emax2;
1075 int max_support2 = emax2;
1076 for (const int64 index1 : InitAndGetValues(expr1_iterator_)) {
1077 for (const int64 index2 : InitAndGetValues(expr2_iterator_)) {
1078 const int64 value = ElementValue(index1, index2);
1079 if (value > max_value) {
1080 max_value = value;
1081 max_support1 = index1;
1082 max_support2 = index2;
1083 } else if (value < min_value) {
1084 min_value = value;
1085 min_support1 = index1;
1086 min_support2 = index2;
1087 }
1088 }
1089 }
1090 Solver* s = solver();
1091 s->SaveAndSetValue(&min_, min_value);
1092 s->SaveAndSetValue(&min_support1_, min_support1);
1093 s->SaveAndSetValue(&min_support2_, min_support2);
1094 s->SaveAndSetValue(&max_, max_value);
1095 s->SaveAndSetValue(&max_support1_, max_support1);
1096 s->SaveAndSetValue(&max_support2_, max_support2);
1097 s->SaveAndSetValue(&initial_update_, false);
1098 }
1099}
1100} // namespace
1101
1103 IntVar* const index1, IntVar* const index2) {
1104 CHECK_EQ(this, index1->solver());
1105 CHECK_EQ(this, index2->solver());
1107 new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
1108}
1109
1110// ---------- Generalized element ----------
1111
1112// ----- IfThenElseCt -----
1113
1115 public:
1116 IfThenElseCt(Solver* const solver, IntVar* const condition,
1117 IntExpr* const one, IntExpr* const zero, IntVar* const target)
1118 : CastConstraint(solver, target),
1119 condition_(condition),
1120 zero_(zero),
1121 one_(one) {}
1122
1123 ~IfThenElseCt() override {}
1124
1125 void Post() override {
1126 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1127 condition_->WhenBound(demon);
1128 one_->WhenRange(demon);
1129 zero_->WhenRange(demon);
1130 target_var_->WhenRange(demon);
1131 }
1132
1133 void InitialPropagate() override {
1134 condition_->SetRange(0, 1);
1135 const int64 target_var_min = target_var_->Min();
1136 const int64 target_var_max = target_var_->Max();
1137 int64 new_min = kint64min;
1138 int64 new_max = kint64max;
1139 if (condition_->Max() == 0) {
1140 zero_->SetRange(target_var_min, target_var_max);
1141 zero_->Range(&new_min, &new_max);
1142 } else if (condition_->Min() == 1) {
1143 one_->SetRange(target_var_min, target_var_max);
1144 one_->Range(&new_min, &new_max);
1145 } else {
1146 if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
1147 condition_->SetValue(1);
1148 one_->SetRange(target_var_min, target_var_max);
1149 one_->Range(&new_min, &new_max);
1150 } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
1151 condition_->SetValue(0);
1152 zero_->SetRange(target_var_min, target_var_max);
1153 zero_->Range(&new_min, &new_max);
1154 } else {
1155 int64 zl = 0;
1156 int64 zu = 0;
1157 int64 ol = 0;
1158 int64 ou = 0;
1159 zero_->Range(&zl, &zu);
1160 one_->Range(&ol, &ou);
1161 new_min = std::min(zl, ol);
1162 new_max = std::max(zu, ou);
1163 }
1164 }
1165 target_var_->SetRange(new_min, new_max);
1166 }
1167
1168 std::string DebugString() const override {
1169 return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
1170 one_->DebugString(), zero_->DebugString(),
1172 }
1173
1174 void Accept(ModelVisitor* const visitor) const override {}
1175
1176 private:
1177 IntVar* const condition_;
1178 IntExpr* const zero_;
1179 IntExpr* const one_;
1180};
1181
1182// ----- IntExprEvaluatorElementCt -----
1183
1184// This constraint implements evaluator(index) == var. It is delayed such
1185// that propagation only occurs when all variables have been touched.
1186// The range of the evaluator is [range_start, range_end).
1187
1188namespace {
1189class IntExprEvaluatorElementCt : public CastConstraint {
1190 public:
1191 IntExprEvaluatorElementCt(Solver* const s, Solver::Int64ToIntVar evaluator,
1192 int64 range_start, int64 range_end,
1193 IntVar* const index, IntVar* const target_var);
1194 ~IntExprEvaluatorElementCt() override {}
1195
1196 void Post() override;
1197 void InitialPropagate() override;
1198
1199 void Propagate();
1200 void Update(int index);
1201 void UpdateExpr();
1202
1203 std::string DebugString() const override;
1204 void Accept(ModelVisitor* const visitor) const override;
1205
1206 protected:
1207 IntVar* const index_;
1208
1209 private:
1211 const int64 range_start_;
1212 const int64 range_end_;
1213 int min_support_;
1214 int max_support_;
1215};
1216
1217IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1218 Solver* const s, Solver::Int64ToIntVar evaluator, int64 range_start,
1219 int64 range_end, IntVar* const index, IntVar* const target_var)
1220 : CastConstraint(s, target_var),
1221 index_(index),
1222 evaluator_(std::move(evaluator)),
1223 range_start_(range_start),
1224 range_end_(range_end),
1225 min_support_(-1),
1226 max_support_(-1) {}
1227
1228void IntExprEvaluatorElementCt::Post() {
1229 Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
1230 solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
1231 for (int i = range_start_; i < range_end_; ++i) {
1232 IntVar* const current_var = evaluator_(i);
1233 current_var->WhenRange(delayed_propagate_demon);
1234 Demon* const update_demon = MakeConstraintDemon1(
1235 solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
1236 current_var->WhenRange(update_demon);
1237 }
1238 index_->WhenRange(delayed_propagate_demon);
1239 Demon* const update_expr_demon = MakeConstraintDemon0(
1240 solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
1241 index_->WhenRange(update_expr_demon);
1242 Demon* const update_var_demon = MakeConstraintDemon0(
1243 solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
1244
1245 target_var_->WhenRange(update_var_demon);
1246}
1247
1248void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1249
1250void IntExprEvaluatorElementCt::Propagate() {
1251 const int64 emin = std::max(range_start_, index_->Min());
1252 const int64 emax = std::min<int64>(range_end_ - 1, index_->Max());
1253 const int64 vmin = target_var_->Min();
1254 const int64 vmax = target_var_->Max();
1255 if (emin == emax) {
1256 index_->SetValue(emin); // in case it was reduced by the above min/max.
1257 evaluator_(emin)->SetRange(vmin, vmax);
1258 } else {
1259 int64 nmin = emin;
1260 for (; nmin <= emax; nmin++) {
1261 // break if the intersection of
1262 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1263 // is non-empty.
1264 IntVar* const nmin_var = evaluator_(nmin);
1265 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
1266 }
1267 int64 nmax = emax;
1268 for (; nmin <= nmax; nmax--) {
1269 // break if the intersection of
1270 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1271 // is non-empty.
1272 IntExpr* const nmax_var = evaluator_(nmax);
1273 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
1274 }
1275 index_->SetRange(nmin, nmax);
1276 if (nmin == nmax) {
1277 evaluator_(nmin)->SetRange(vmin, vmax);
1278 }
1279 }
1280 if (min_support_ == -1 || max_support_ == -1) {
1281 int min_support = -1;
1282 int max_support = -1;
1283 int64 gmin = kint64max;
1284 int64 gmax = kint64min;
1285 for (int i = index_->Min(); i <= index_->Max(); ++i) {
1286 IntExpr* const var_i = evaluator_(i);
1287 const int64 vmin = var_i->Min();
1288 if (vmin < gmin) {
1289 gmin = vmin;
1290 }
1291 const int64 vmax = var_i->Max();
1292 if (vmax > gmax) {
1293 gmax = vmax;
1294 }
1295 }
1296 solver()->SaveAndSetValue(&min_support_, min_support);
1297 solver()->SaveAndSetValue(&max_support_, max_support);
1298 target_var_->SetRange(gmin, gmax);
1299 }
1300}
1301
1302void IntExprEvaluatorElementCt::Update(int index) {
1303 if (index == min_support_ || index == max_support_) {
1304 solver()->SaveAndSetValue(&min_support_, -1);
1305 solver()->SaveAndSetValue(&max_support_, -1);
1306 }
1307}
1308
1309void IntExprEvaluatorElementCt::UpdateExpr() {
1310 if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
1311 solver()->SaveAndSetValue(&min_support_, -1);
1312 solver()->SaveAndSetValue(&max_support_, -1);
1313 }
1314}
1315
1316namespace {
1317std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
1318 int64 range_start, int64 range_end) {
1319 std::string out;
1320 for (int64 i = range_start; i < range_end; ++i) {
1321 if (i != range_start) {
1322 out += ", ";
1323 }
1324 out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
1325 }
1326 return out;
1327}
1328
1329std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
1330 int64 range_begin, int64 range_end) {
1331 std::string out;
1332 if (range_end - range_begin > 10) {
1333 out = absl::StrFormat(
1334 "IntToIntVar(%s, ...%s)",
1335 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1336 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1337 } else {
1338 out = absl::StrFormat(
1339 "IntToIntVar(%s)",
1340 StringifyEvaluatorBare(evaluator, range_begin, range_end));
1341 }
1342 return out;
1343}
1344} // namespace
1345
1346std::string IntExprEvaluatorElementCt::DebugString() const {
1347 return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
1348}
1349
1350void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1351 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1352 visitor->VisitIntegerVariableEvaluatorArgument(
1354 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1355 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1356 target_var_);
1357 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1358}
1359
1360// ----- IntExprArrayElementCt -----
1361
1362// This constraint implements vars[index] == var. It is delayed such
1363// that propagation only occurs when all variables have been touched.
1364
1365class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1366 public:
1367 IntExprArrayElementCt(Solver* const s, std::vector<IntVar*> vars,
1368 IntVar* const index, IntVar* const target_var);
1369
1370 std::string DebugString() const override;
1371 void Accept(ModelVisitor* const visitor) const override;
1372
1373 private:
1374 const std::vector<IntVar*> vars_;
1375};
1376
1377IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1378 std::vector<IntVar*> vars,
1379 IntVar* const index,
1380 IntVar* const target_var)
1381 : IntExprEvaluatorElementCt(
1382 s, [this](int64 idx) { return vars_[idx]; }, 0, vars.size(), index,
1383 target_var),
1384 vars_(std::move(vars)) {}
1385
1386std::string IntExprArrayElementCt::DebugString() const {
1387 int64 size = vars_.size();
1388 if (size > 10) {
1389 return absl::StrFormat(
1390 "IntExprArrayElement(var array of size %d, %s) == %s", size,
1391 index_->DebugString(), target_var_->DebugString());
1392 } else {
1393 return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
1394 JoinDebugStringPtr(vars_, ", "),
1395 index_->DebugString(), target_var_->DebugString());
1396 }
1397}
1398
1399void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1400 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1401 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1402 vars_);
1403 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1404 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1405 target_var_);
1406 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1407}
1408
1409// ----- IntExprArrayElementCstCt -----
1410
1411// This constraint implements vars[index] == constant.
1412
1413class IntExprArrayElementCstCt : public Constraint {
1414 public:
1415 IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
1416 IntVar* const index, int64 target)
1417 : Constraint(s),
1418 vars_(vars),
1419 index_(index),
1420 target_(target),
1421 demons_(vars.size()) {}
1422
1423 ~IntExprArrayElementCstCt() override {}
1424
1425 void Post() override {
1426 for (int i = 0; i < vars_.size(); ++i) {
1427 demons_[i] = MakeConstraintDemon1(
1428 solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
1429 vars_[i]->WhenDomain(demons_[i]);
1430 }
1431 Demon* const index_demon = MakeConstraintDemon0(
1432 solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
1433 "PropagateIndex");
1434 index_->WhenBound(index_demon);
1435 }
1436
1437 void InitialPropagate() override {
1438 for (int i = 0; i < vars_.size(); ++i) {
1439 Propagate(i);
1440 }
1441 PropagateIndex();
1442 }
1443
1444 void Propagate(int index) {
1445 if (!vars_[index]->Contains(target_)) {
1446 index_->RemoveValue(index);
1447 demons_[index]->inhibit(solver());
1448 }
1449 }
1450
1451 void PropagateIndex() {
1452 if (index_->Bound()) {
1453 vars_[index_->Min()]->SetValue(target_);
1454 }
1455 }
1456
1457 std::string DebugString() const override {
1458 return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
1459 JoinDebugStringPtr(vars_, ", "),
1460 index_->DebugString(), target_);
1461 }
1462
1463 void Accept(ModelVisitor* const visitor) const override {
1464 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1465 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1466 vars_);
1467 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1468 index_);
1469 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1470 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1471 }
1472
1473 private:
1474 const std::vector<IntVar*> vars_;
1475 IntVar* const index_;
1476 const int64 target_;
1477 std::vector<Demon*> demons_;
1478};
1479
1480// This constraint implements index == position(constant in vars).
1481
1482class IntExprIndexOfCt : public Constraint {
1483 public:
1484 IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
1485 IntVar* const index, int64 target)
1486 : Constraint(s),
1487 vars_(vars),
1488 index_(index),
1489 target_(target),
1490 demons_(vars_.size()),
1491 index_iterator_(index->MakeHoleIterator(true)) {}
1492
1493 ~IntExprIndexOfCt() override {}
1494
1495 void Post() override {
1496 for (int i = 0; i < vars_.size(); ++i) {
1497 demons_[i] = MakeConstraintDemon1(
1498 solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
1499 vars_[i]->WhenDomain(demons_[i]);
1500 }
1501 Demon* const index_demon = MakeConstraintDemon0(
1502 solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
1503 index_->WhenDomain(index_demon);
1504 }
1505
1506 void InitialPropagate() override {
1507 for (int i = 0; i < vars_.size(); ++i) {
1508 if (!index_->Contains(i)) {
1509 vars_[i]->RemoveValue(target_);
1510 } else if (!vars_[i]->Contains(target_)) {
1511 index_->RemoveValue(i);
1512 demons_[i]->inhibit(solver());
1513 } else if (vars_[i]->Bound()) {
1514 index_->SetValue(i);
1515 demons_[i]->inhibit(solver());
1516 }
1517 }
1518 }
1519
1520 void Propagate(int index) {
1521 if (!vars_[index]->Contains(target_)) {
1522 index_->RemoveValue(index);
1523 demons_[index]->inhibit(solver());
1524 } else if (vars_[index]->Bound()) {
1525 index_->SetValue(index);
1526 }
1527 }
1528
1529 void PropagateIndex() {
1530 const int64 oldmax = index_->OldMax();
1531 const int64 vmin = index_->Min();
1532 const int64 vmax = index_->Max();
1533 for (int64 value = index_->OldMin(); value < vmin; ++value) {
1534 vars_[value]->RemoveValue(target_);
1535 demons_[value]->inhibit(solver());
1536 }
1537 for (const int64 value : InitAndGetValues(index_iterator_)) {
1538 vars_[value]->RemoveValue(target_);
1539 demons_[value]->inhibit(solver());
1540 }
1541 for (int64 value = vmax + 1; value <= oldmax; ++value) {
1542 vars_[value]->RemoveValue(target_);
1543 demons_[value]->inhibit(solver());
1544 }
1545 if (index_->Bound()) {
1546 vars_[index_->Min()]->SetValue(target_);
1547 }
1548 }
1549
1550 std::string DebugString() const override {
1551 return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
1552 JoinDebugStringPtr(vars_, ", "),
1553 index_->DebugString(), target_);
1554 }
1555
1556 void Accept(ModelVisitor* const visitor) const override {
1557 visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
1558 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1559 vars_);
1560 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1561 index_);
1562 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1563 visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
1564 }
1565
1566 private:
1567 const std::vector<IntVar*> vars_;
1568 IntVar* const index_;
1569 const int64 target_;
1570 std::vector<Demon*> demons_;
1571 IntVarIterator* const index_iterator_;
1572};
1573
1574// Factory helper.
1575
1576Constraint* MakeElementEqualityFunc(Solver* const solver,
1577 const std::vector<int64>& vals,
1578 IntVar* const index, IntVar* const target) {
1579 if (index->Bound()) {
1580 const int64 val = index->Min();
1581 if (val < 0 || val >= vals.size()) {
1582 return solver->MakeFalseConstraint();
1583 } else {
1584 return solver->MakeEquality(target, vals[val]);
1585 }
1586 } else {
1587 if (IsIncreasingContiguous(vals)) {
1588 return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
1589 } else {
1590 return solver->RevAlloc(
1591 new IntElementConstraint(solver, vals, index, target));
1592 }
1593 }
1594}
1595} // namespace
1596
1598 IntExpr* const then_expr,
1599 IntExpr* const else_expr,
1600 IntVar* const target_var) {
1601 return RevAlloc(
1602 new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
1603}
1604
1605IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
1606 IntVar* const index) {
1607 if (index->Bound()) {
1608 return vars[index->Min()];
1609 }
1610 const int size = vars.size();
1611 if (AreAllBound(vars)) {
1612 std::vector<int64> values(size);
1613 for (int i = 0; i < size; ++i) {
1614 values[i] = vars[i]->Value();
1615 }
1616 return MakeElement(values, index);
1617 }
1618 if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
1619 index->Min() >= 0 && index->Max() < vars.size()) {
1620 // Let's get the index between 0 and 1.
1621 IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
1622 IntVar* const zero = vars[index->Min()];
1623 IntVar* const one = vars[index->Max()];
1624 const std::string name = absl::StrFormat(
1625 "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
1626 IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
1627 std::max(zero->Max(), one->Max()), name);
1629 RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
1630 return target;
1631 }
1632 int64 emin = kint64max;
1633 int64 emax = kint64min;
1634 std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
1635 for (const int64 index_value : InitAndGetValues(iterator.get())) {
1636 if (index_value >= 0 && index_value < size) {
1637 emin = std::min(emin, vars[index_value]->Min());
1638 emax = std::max(emax, vars[index_value]->Max());
1639 }
1640 }
1641 const std::string vname =
1642 size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
1643 index->DebugString())
1644 : absl::StrFormat("ElementVar([%s], %s)",
1645 JoinNamePtr(vars, ", "), index->name());
1646 IntVar* const element_var = MakeIntVar(emin, emax, vname);
1648 RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
1649 return element_var;
1650}
1651
1653 int64 range_end, IntVar* argument) {
1654 const std::string index_name =
1655 !argument->name().empty() ? argument->name() : argument->DebugString();
1656 const std::string vname = absl::StrFormat(
1657 "ElementVar(%s, %s)",
1658 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1659 IntVar* const element_var = MakeIntVar(kint64min, kint64max, vname);
1660 IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
1661 this, std::move(vars), range_start, range_end, argument, element_var);
1662 AddConstraint(RevAlloc(evaluation_ct));
1663 evaluation_ct->Propagate();
1664 return element_var;
1665}
1666
1667Constraint* Solver::MakeElementEquality(const std::vector<int64>& vals,
1668 IntVar* const index,
1669 IntVar* const target) {
1670 return MakeElementEqualityFunc(this, vals, index, target);
1671}
1672
1673Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
1674 IntVar* const index,
1675 IntVar* const target) {
1676 return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
1677}
1678
1679Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1680 IntVar* const index,
1681 IntVar* const target) {
1682 if (AreAllBound(vars)) {
1683 std::vector<int64> values(vars.size());
1684 for (int i = 0; i < vars.size(); ++i) {
1685 values[i] = vars[i]->Value();
1686 }
1687 return MakeElementEquality(values, index, target);
1688 }
1689 if (index->Bound()) {
1690 const int64 val = index->Min();
1691 if (val < 0 || val >= vars.size()) {
1692 return MakeFalseConstraint();
1693 } else {
1694 return MakeEquality(target, vars[val]);
1695 }
1696 } else {
1697 if (target->Bound()) {
1698 return RevAlloc(
1699 new IntExprArrayElementCstCt(this, vars, index, target->Min()));
1700 } else {
1701 return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
1702 }
1703 }
1704}
1705
1706Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1707 IntVar* const index, int64 target) {
1708 if (AreAllBound(vars)) {
1709 std::vector<int> valid_indices;
1710 for (int i = 0; i < vars.size(); ++i) {
1711 if (vars[i]->Value() == target) {
1712 valid_indices.push_back(i);
1713 }
1714 }
1715 return MakeMemberCt(index, valid_indices);
1716 }
1717 if (index->Bound()) {
1718 const int64 pos = index->Min();
1719 if (pos >= 0 && pos < vars.size()) {
1720 IntVar* const var = vars[pos];
1721 return MakeEquality(var, target);
1722 } else {
1723 return MakeFalseConstraint();
1724 }
1725 } else {
1726 return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
1727 }
1728}
1729
1730Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
1731 IntVar* const index, int64 target) {
1732 if (index->Bound()) {
1733 const int64 pos = index->Min();
1734 if (pos >= 0 && pos < vars.size()) {
1735 IntVar* const var = vars[pos];
1736 return MakeEquality(var, target);
1737 } else {
1738 return MakeFalseConstraint();
1739 }
1740 } else {
1741 return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
1742 }
1743}
1744
1745IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
1746 int64 value) {
1747 IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
1749 if (cache != nullptr) {
1750 return cache->Var();
1751 } else {
1752 const std::string name =
1753 absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
1754 IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
1756 model_cache_->InsertVarArrayConstantExpression(
1758 return index;
1759 }
1760}
1761} // namespace operations_research
int64 min
Definition: alldiff_cst.cc:138
const std::vector< IntVar * > vars_
Definition: alldiff_cst.cc:43
int64 max
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:495
#define DCHECK_LE(val1, val2)
Definition: base/logging.h:887
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:697
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:889
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:890
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:888
#define DCHECK(condition)
Definition: base/logging.h:884
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:885
This is the base class for all expressions that are not variables.
Cast constraints are special channeling constraints designed to keep a variable in sync with an expre...
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
void Post() override
This method is called when the constraint is processed by the solver.
Definition: element.cc:1125
void InitialPropagate() override
This method performs the initial propagation of the constraint.
Definition: element.cc:1133
IfThenElseCt(Solver *const solver, IntVar *const condition, IntExpr *const one, IntExpr *const zero, IntVar *const target)
Definition: element.cc:1116
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
Definition: element.cc:1174
std::string DebugString() const override
Definition: element.cc:1168
Utility class to encapsulate an IntVarIterator and use it in a range-based loop.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual void SetRange(int64 l, int64 u)
This method sets both the min and the max of the expression.
virtual void SetValue(int64 v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void Range(int64 *l, int64 *u)
By default calls Min() and Max(), but can be redefined when Min and Max code can be factorized.
virtual int64 Max() const =0
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64 Min() const =0
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
The class IntVar is a subset of IntExpr.
virtual void WhenBound(Demon *d)=0
This method attaches a demon that will be awakened when the variable is bound.
virtual bool Contains(int64 v) const =0
This method returns whether the value 'v' is in the domain of the variable.
virtual std::string name() const
Object naming.
IntExpr * MakeElement(const std::vector< int64 > &values, IntVar *const index)
values[index]
Definition: element.cc:647
Constraint * MakeMemberCt(IntExpr *const expr, const std::vector< int64 > &values)
expr in set.
Definition: expr_cst.cc:1160
IntExpr * RegisterIntExpr(IntExpr *const expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Definition: trace.cc:844
Constraint * MakeFalseConstraint()
This constraint always fails.
Definition: constraints.cc:520
Constraint * MakeEquality(IntExpr *const left, IntExpr *const right)
left == right
Definition: range_cst.cc:512
void AddConstraint(Constraint *const c)
Adds the constraint 'c' to the model.
IntExpr * MakeOpposite(IntExpr *const expr)
-expr
Constraint * MakeIfThenElseCt(IntVar *const condition, IntExpr *const then_expr, IntExpr *const else_expr, IntVar *const target_var)
Special cases with arrays of size two.
Definition: element.cc:1597
Demon * MakeConstraintInitialPropagateCallback(Constraint *const ct)
This method is a specialized case of the MakeConstraintDemon method to call the InitialPropagate of ...
Definition: constraints.cc:33
std::function< int64(int64)> IndexEvaluator1
Callback typedefs.
std::function< int64(int64, int64)> IndexEvaluator2
IntExpr * MakeSum(IntExpr *const left, IntExpr *const right)
left + right.
IntExpr * MakeIndexExpression(const std::vector< IntVar * > &vars, int64 value)
Returns the expression expr such that vars[expr] == value.
Definition: element.cc:1745
IntVar * MakeIntVar(int64 min, int64 max, const std::string &name)
MakeIntVar will create the best range based int var for the bounds given.
Constraint * MakeIndexOfConstraint(const std::vector< IntVar * > &vars, IntVar *const index, int64 target)
This constraint is a special case of the element constraint with an array of integer variables,...
Definition: element.cc:1730
std::function< IntVar *(int64)> Int64ToIntVar
Constraint * MakeElementEquality(const std::vector< int64 > &vals, IntVar *const index, IntVar *const target)
Definition: element.cc:1667
T * RevAlloc(T *object)
Registers the given object as being reversible.
IntExpr * MakeMonotonicElement(IndexEvaluator1 values, bool increasing, IntVar *const index)
Function based element.
Definition: element.cc:859
std::vector< int64 > to_remove_
const std::string name
int64 value
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:977
IntVar *const expr_
Definition: element.cc:85
ABSL_FLAG(bool, cp_disable_element_cache, true, "If true, caching for IntElement is disabled.")
#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:127
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:417
IntVar * var
Definition: expr_array.cc:1858
static const int64 kint64max
int64_t int64
uint64_t uint64
static const int64 kint64min
std::function< int64(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1487
The vehicle routing library lets one model and solve generic vehicle routing problems ranging from th...
std::string JoinNamePtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:52
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
bool AreAllBound(const std::vector< IntVar * > &vars)
bool IsArrayConstant(const std::vector< T > &values, const T &value)
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
bool IsIncreasingContiguous(const std::vector< T > &values)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::string JoinDebugStringPtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:45
std::vector< int64 > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:822
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
bool IsIncreasing(const std::vector< T > &values)
bool IsArrayBoolean(const std::vector< T > &values)
STL namespace.
int index
Definition: pack.cc:508
IntervalVar *const target_var_
std::function< int64(int64, int64)> evaluator_
Definition: search.cc:1361