OR-Tools  8.2
expr_array.cc
Go to the documentation of this file.
1// Copyright 2010-2018 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//
15// Array Expression constraints
16
17#include <algorithm>
18#include <cmath>
19#include <string>
20#include <vector>
21
22#include "absl/strings/str_format.h"
23#include "absl/strings/str_join.h"
31
32namespace operations_research {
33namespace {
34// ----- Tree Array Constraint -----
35
36class TreeArrayConstraint : public CastConstraint {
37 public:
38 TreeArrayConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
39 IntVar* const sum_var)
40 : CastConstraint(solver, sum_var),
41 vars_(vars),
42 block_size_(solver->parameters().array_split_size()) {
43 std::vector<int> lengths;
44 lengths.push_back(vars_.size());
45 while (lengths.back() > 1) {
46 const int current = lengths.back();
47 lengths.push_back((current + block_size_ - 1) / block_size_);
48 }
49 tree_.resize(lengths.size());
50 for (int i = 0; i < lengths.size(); ++i) {
51 tree_[i].resize(lengths[lengths.size() - i - 1]);
52 }
53 DCHECK_GE(tree_.size(), 1);
54 DCHECK_EQ(1, tree_[0].size());
55 root_node_ = &tree_[0][0];
56 }
57
58 std::string DebugStringInternal(const std::string& name) const {
59 return absl::StrFormat("%s(%s) == %s", name,
61 target_var_->DebugString());
62 }
63
64 void AcceptInternal(const std::string& name,
65 ModelVisitor* const visitor) const {
66 visitor->BeginVisitConstraint(name, this);
67 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
68 vars_);
69 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
71 visitor->EndVisitConstraint(name, this);
72 }
73
74 // Increases min by delta_min, reduces max by delta_max.
75 void ReduceRange(int depth, int position, int64 delta_min, int64 delta_max) {
76 NodeInfo* const info = &tree_[depth][position];
77 if (delta_min > 0) {
78 info->node_min.SetValue(solver(),
79 CapAdd(info->node_min.Value(), delta_min));
80 }
81 if (delta_max > 0) {
82 info->node_max.SetValue(solver(),
83 CapSub(info->node_max.Value(), delta_max));
84 }
85 }
86
87 // Sets the range on the given node.
88 void SetRange(int depth, int position, int64 new_min, int64 new_max) {
89 NodeInfo* const info = &tree_[depth][position];
90 if (new_min > info->node_min.Value()) {
91 info->node_min.SetValue(solver(), new_min);
92 }
93 if (new_max < info->node_max.Value()) {
94 info->node_max.SetValue(solver(), new_max);
95 }
96 }
97
98 void InitLeaf(int position, int64 var_min, int64 var_max) {
99 InitNode(MaxDepth(), position, var_min, var_max);
100 }
101
102 void InitNode(int depth, int position, int64 node_min, int64 node_max) {
103 tree_[depth][position].node_min.SetValue(solver(), node_min);
104 tree_[depth][position].node_max.SetValue(solver(), node_max);
105 }
106
107 int64 Min(int depth, int position) const {
108 return tree_[depth][position].node_min.Value();
109 }
110
111 int64 Max(int depth, int position) const {
112 return tree_[depth][position].node_max.Value();
113 }
114
115 int64 RootMin() const { return root_node_->node_min.Value(); }
116
117 int64 RootMax() const { return root_node_->node_max.Value(); }
118
119 int Parent(int position) const { return position / block_size_; }
120
121 int ChildStart(int position) const { return position * block_size_; }
122
123 int ChildEnd(int depth, int position) const {
124 DCHECK_LT(depth + 1, tree_.size());
125 return std::min((position + 1) * block_size_ - 1, Width(depth + 1) - 1);
126 }
127
128 bool IsLeaf(int depth) const { return depth == MaxDepth(); }
129
130 int MaxDepth() const { return tree_.size() - 1; }
131
132 int Width(int depth) const { return tree_[depth].size(); }
133
134 protected:
135 const std::vector<IntVar*> vars_;
136
137 private:
138 struct NodeInfo {
139 NodeInfo() : node_min(0), node_max(0) {}
140 Rev<int64> node_min;
141 Rev<int64> node_max;
142 };
143
144 std::vector<std::vector<NodeInfo> > tree_;
145 const int block_size_;
146 NodeInfo* root_node_;
147};
148
149// ---------- Sum Array ----------
150
151// Some of these optimizations here are described in:
152// "Bounds consistency techniques for long linear constraints". In
153// Workshop on Techniques for Implementing Constraint Programming
154// Systems (TRICS), a workshop of CP 2002, N. Beldiceanu, W. Harvey,
155// Martin Henz, Francois Laburthe, Eric Monfroy, Tobias Müller,
156// Laurent Perron and Christian Schulte editors, pages 39-46, 2002.
157
158// ----- SumConstraint -----
159
160// This constraint implements sum(vars) == sum_var.
161class SumConstraint : public TreeArrayConstraint {
162 public:
163 SumConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
164 IntVar* const sum_var)
165 : TreeArrayConstraint(solver, vars, sum_var), sum_demon_(nullptr) {}
166
167 ~SumConstraint() override {}
168
169 void Post() override {
170 for (int i = 0; i < vars_.size(); ++i) {
171 Demon* const demon = MakeConstraintDemon1(
172 solver(), this, &SumConstraint::LeafChanged, "LeafChanged", i);
173 vars_[i]->WhenRange(demon);
174 }
175 sum_demon_ = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
176 solver(), this, &SumConstraint::SumChanged, "SumChanged"));
177 target_var_->WhenRange(sum_demon_);
178 }
179
180 void InitialPropagate() override {
181 // Copy vars to leaf nodes.
182 for (int i = 0; i < vars_.size(); ++i) {
183 InitLeaf(i, vars_[i]->Min(), vars_[i]->Max());
184 }
185 // Compute up.
186 for (int i = MaxDepth() - 1; i >= 0; --i) {
187 for (int j = 0; j < Width(i); ++j) {
188 int64 sum_min = 0;
189 int64 sum_max = 0;
190 const int block_start = ChildStart(j);
191 const int block_end = ChildEnd(i, j);
192 for (int k = block_start; k <= block_end; ++k) {
193 sum_min = CapAdd(sum_min, Min(i + 1, k));
194 sum_max = CapAdd(sum_max, Max(i + 1, k));
195 }
196 InitNode(i, j, sum_min, sum_max);
197 }
198 }
199 // Propagate to sum_var.
200 target_var_->SetRange(RootMin(), RootMax());
201
202 // Push down.
203 SumChanged();
204 }
205
206 void SumChanged() {
207 if (target_var_->Max() == RootMin() && target_var_->Max() != kint64max) {
208 // We can fix all terms to min.
209 for (int i = 0; i < vars_.size(); ++i) {
210 vars_[i]->SetValue(vars_[i]->Min());
211 }
212 } else if (target_var_->Min() == RootMax() &&
213 target_var_->Min() != kint64min) {
214 // We can fix all terms to max.
215 for (int i = 0; i < vars_.size(); ++i) {
216 vars_[i]->SetValue(vars_[i]->Max());
217 }
218 } else {
219 PushDown(0, 0, target_var_->Min(), target_var_->Max());
220 }
221 }
222
223 void PushDown(int depth, int position, int64 new_min, int64 new_max) {
224 // Nothing to do?
225 if (new_min <= Min(depth, position) && new_max >= Max(depth, position)) {
226 return;
227 }
228
229 // Leaf node -> push to leaf var.
230 if (IsLeaf(depth)) {
231 vars_[position]->SetRange(new_min, new_max);
232 return;
233 }
234
235 // Standard propagation from the bounds of the sum to the
236 // individuals terms.
237
238 // These are maintained automatically in the tree structure.
239 const int64 sum_min = Min(depth, position);
240 const int64 sum_max = Max(depth, position);
241
242 // Intersect the new bounds with the computed bounds.
243 new_max = std::min(sum_max, new_max);
244 new_min = std::max(sum_min, new_min);
245
246 // Detect failure early.
247 if (new_max < sum_min || new_min > sum_max) {
248 solver()->Fail();
249 }
250
251 // Push to children nodes.
252 const int block_start = ChildStart(position);
253 const int block_end = ChildEnd(depth, position);
254 for (int i = block_start; i <= block_end; ++i) {
255 const int64 target_var_min = Min(depth + 1, i);
256 const int64 target_var_max = Max(depth + 1, i);
257 const int64 residual_min = CapSub(sum_min, target_var_min);
258 const int64 residual_max = CapSub(sum_max, target_var_max);
259 PushDown(depth + 1, i, CapSub(new_min, residual_max),
260 CapSub(new_max, residual_min));
261 }
262 // TODO(user) : Is the diameter optimization (see reference
263 // above, rule 5) useful?
264 }
265
266 void LeafChanged(int term_index) {
267 IntVar* const var = vars_[term_index];
268 PushUp(term_index, CapSub(var->Min(), var->OldMin()),
269 CapSub(var->OldMax(), var->Max()));
270 EnqueueDelayedDemon(sum_demon_); // TODO(user): Is this needed?
271 }
272
273 void PushUp(int position, int64 delta_min, int64 delta_max) {
274 DCHECK_GE(delta_max, 0);
275 DCHECK_GE(delta_min, 0);
276 DCHECK_GT(CapAdd(delta_min, delta_max), 0);
277 for (int depth = MaxDepth(); depth >= 0; --depth) {
278 ReduceRange(depth, position, delta_min, delta_max);
279 position = Parent(position);
280 }
281 DCHECK_EQ(0, position);
282 target_var_->SetRange(RootMin(), RootMax());
283 }
284
285 std::string DebugString() const override {
286 return DebugStringInternal("Sum");
287 }
288
289 void Accept(ModelVisitor* const visitor) const override {
290 AcceptInternal(ModelVisitor::kSumEqual, visitor);
291 }
292
293 private:
294 Demon* sum_demon_;
295};
296
297// This constraint implements sum(vars) == target_var.
298class SmallSumConstraint : public Constraint {
299 public:
300 SmallSumConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
301 IntVar* const target_var)
302 : Constraint(solver),
303 vars_(vars),
304 target_var_(target_var),
305 computed_min_(0),
306 computed_max_(0),
307 sum_demon_(nullptr) {}
308
309 ~SmallSumConstraint() override {}
310
311 void Post() override {
312 for (int i = 0; i < vars_.size(); ++i) {
313 if (!vars_[i]->Bound()) {
314 Demon* const demon = MakeConstraintDemon1(
315 solver(), this, &SmallSumConstraint::VarChanged, "VarChanged",
316 vars_[i]);
317 vars_[i]->WhenRange(demon);
318 }
319 }
320 sum_demon_ = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
321 solver(), this, &SmallSumConstraint::SumChanged, "SumChanged"));
322 target_var_->WhenRange(sum_demon_);
323 }
324
325 void InitialPropagate() override {
326 // Compute up.
327 int64 sum_min = 0;
328 int64 sum_max = 0;
329 for (IntVar* const var : vars_) {
330 sum_min = CapAdd(sum_min, var->Min());
331 sum_max = CapAdd(sum_max, var->Max());
332 }
333
334 // Propagate to sum_var.
335 computed_min_.SetValue(solver(), sum_min);
336 computed_max_.SetValue(solver(), sum_max);
337 target_var_->SetRange(sum_min, sum_max);
338
339 // Push down.
340 SumChanged();
341 }
342
343 void SumChanged() {
344 int64 new_min = target_var_->Min();
345 int64 new_max = target_var_->Max();
346 const int64 sum_min = computed_min_.Value();
347 const int64 sum_max = computed_max_.Value();
348 if (new_max == sum_min && new_max != kint64max) {
349 // We can fix all terms to min.
350 for (int i = 0; i < vars_.size(); ++i) {
351 vars_[i]->SetValue(vars_[i]->Min());
352 }
353 } else if (new_min == sum_max && new_min != kint64min) {
354 // We can fix all terms to max.
355 for (int i = 0; i < vars_.size(); ++i) {
356 vars_[i]->SetValue(vars_[i]->Max());
357 }
358 } else {
359 if (new_min > sum_min || new_max < sum_max) { // something to do.
360 // Intersect the new bounds with the computed bounds.
361 new_max = std::min(sum_max, new_max);
362 new_min = std::max(sum_min, new_min);
363
364 // Detect failure early.
365 if (new_max < sum_min || new_min > sum_max) {
366 solver()->Fail();
367 }
368
369 // Push to variables.
370 for (IntVar* const var : vars_) {
371 const int64 var_min = var->Min();
372 const int64 var_max = var->Max();
373 const int64 residual_min = CapSub(sum_min, var_min);
374 const int64 residual_max = CapSub(sum_max, var_max);
375 var->SetRange(CapSub(new_min, residual_max),
376 CapSub(new_max, residual_min));
377 }
378 }
379 }
380 }
381
382 void VarChanged(IntVar* var) {
383 const int64 delta_min = CapSub(var->Min(), var->OldMin());
384 const int64 delta_max = CapSub(var->OldMax(), var->Max());
385 computed_min_.Add(solver(), delta_min);
386 computed_max_.Add(solver(), -delta_max);
387 if (computed_max_.Value() < target_var_->Max() ||
388 computed_min_.Value() > target_var_->Min()) {
389 target_var_->SetRange(computed_min_.Value(), computed_max_.Value());
390 } else {
391 EnqueueDelayedDemon(sum_demon_);
392 }
393 }
394
395 std::string DebugString() const override {
396 return absl::StrFormat("SmallSum(%s) == %s",
398 target_var_->DebugString());
399 }
400
401 void Accept(ModelVisitor* const visitor) const override {
402 visitor->BeginVisitConstraint(ModelVisitor::kSumEqual, this);
403 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
404 vars_);
405 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
407 visitor->EndVisitConstraint(ModelVisitor::kSumEqual, this);
408 }
409
410 private:
411 const std::vector<IntVar*> vars_;
412 IntVar* target_var_;
413 NumericalRev<int64> computed_min_;
414 NumericalRev<int64> computed_max_;
415 Demon* sum_demon_;
416};
417// ----- SafeSumConstraint -----
418
419bool DetectSumOverflow(const std::vector<IntVar*>& vars) {
420 int64 sum_min = 0;
421 int64 sum_max = 0;
422 for (int i = 0; i < vars.size(); ++i) {
423 sum_min = CapAdd(sum_min, vars[i]->Min());
424 sum_max = CapAdd(sum_max, vars[i]->Max());
425 if (sum_min == kint64min || sum_max == kint64max) {
426 return true;
427 }
428 }
429 return false;
430}
431
432// This constraint implements sum(vars) == sum_var.
433class SafeSumConstraint : public TreeArrayConstraint {
434 public:
435 SafeSumConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
436 IntVar* const sum_var)
437 : TreeArrayConstraint(solver, vars, sum_var), sum_demon_(nullptr) {}
438
439 ~SafeSumConstraint() override {}
440
441 void Post() override {
442 for (int i = 0; i < vars_.size(); ++i) {
443 Demon* const demon = MakeConstraintDemon1(
444 solver(), this, &SafeSumConstraint::LeafChanged, "LeafChanged", i);
445 vars_[i]->WhenRange(demon);
446 }
447 sum_demon_ = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
448 solver(), this, &SafeSumConstraint::SumChanged, "SumChanged"));
449 target_var_->WhenRange(sum_demon_);
450 }
451
452 void SafeComputeNode(int depth, int position, int64* const sum_min,
453 int64* const sum_max) {
454 DCHECK_LT(depth, MaxDepth());
455 const int block_start = ChildStart(position);
456 const int block_end = ChildEnd(depth, position);
457 for (int k = block_start; k <= block_end; ++k) {
458 if (*sum_min != kint64min) {
459 *sum_min = CapAdd(*sum_min, Min(depth + 1, k));
460 }
461 if (*sum_max != kint64max) {
462 *sum_max = CapAdd(*sum_max, Max(depth + 1, k));
463 }
464 if (*sum_min == kint64min && *sum_max == kint64max) {
465 break;
466 }
467 }
468 }
469
470 void InitialPropagate() override {
471 // Copy vars to leaf nodes.
472 for (int i = 0; i < vars_.size(); ++i) {
473 InitLeaf(i, vars_[i]->Min(), vars_[i]->Max());
474 }
475 // Compute up.
476 for (int i = MaxDepth() - 1; i >= 0; --i) {
477 for (int j = 0; j < Width(i); ++j) {
478 int64 sum_min = 0;
479 int64 sum_max = 0;
480 SafeComputeNode(i, j, &sum_min, &sum_max);
481 InitNode(i, j, sum_min, sum_max);
482 }
483 }
484 // Propagate to sum_var.
485 target_var_->SetRange(RootMin(), RootMax());
486
487 // Push down.
488 SumChanged();
489 }
490
491 void SumChanged() {
492 DCHECK(CheckInternalState());
493 if (target_var_->Max() == RootMin()) {
494 // We can fix all terms to min.
495 for (int i = 0; i < vars_.size(); ++i) {
496 vars_[i]->SetValue(vars_[i]->Min());
497 }
498 } else if (target_var_->Min() == RootMax()) {
499 // We can fix all terms to max.
500 for (int i = 0; i < vars_.size(); ++i) {
501 vars_[i]->SetValue(vars_[i]->Max());
502 }
503 } else {
504 PushDown(0, 0, target_var_->Min(), target_var_->Max());
505 }
506 }
507
508 void PushDown(int depth, int position, int64 new_min, int64 new_max) {
509 // Nothing to do?
510 if (new_min <= Min(depth, position) && new_max >= Max(depth, position)) {
511 return;
512 }
513
514 // Leaf node -> push to leaf var.
515 if (IsLeaf(depth)) {
516 vars_[position]->SetRange(new_min, new_max);
517 return;
518 }
519
520 // Standard propagation from the bounds of the sum to the
521 // individuals terms.
522
523 // These are maintained automatically in the tree structure.
524 const int64 sum_min = Min(depth, position);
525 const int64 sum_max = Max(depth, position);
526
527 // Intersect the new bounds with the computed bounds.
528 new_max = std::min(sum_max, new_max);
529 new_min = std::max(sum_min, new_min);
530
531 // Detect failure early.
532 if (new_max < sum_min || new_min > sum_max) {
533 solver()->Fail();
534 }
535
536 // Push to children nodes.
537 const int block_start = ChildStart(position);
538 const int block_end = ChildEnd(depth, position);
539 for (int pos = block_start; pos <= block_end; ++pos) {
540 const int64 target_var_min = Min(depth + 1, pos);
541 const int64 residual_min =
542 sum_min != kint64min ? CapSub(sum_min, target_var_min) : kint64min;
543 const int64 target_var_max = Max(depth + 1, pos);
544 const int64 residual_max =
545 sum_max != kint64max ? CapSub(sum_max, target_var_max) : kint64max;
546 PushDown(depth + 1, pos,
547 (residual_max == kint64min ? kint64min
548 : CapSub(new_min, residual_max)),
549 (residual_min == kint64max ? kint64min
550 : CapSub(new_max, residual_min)));
551 }
552 // TODO(user) : Is the diameter optimization (see reference
553 // above, rule 5) useful?
554 }
555
556 void LeafChanged(int term_index) {
557 IntVar* const var = vars_[term_index];
558 PushUp(term_index, CapSub(var->Min(), var->OldMin()),
559 CapSub(var->OldMax(), var->Max()));
560 EnqueueDelayedDemon(sum_demon_); // TODO(user): Is this needed?
561 }
562
563 void PushUp(int position, int64 delta_min, int64 delta_max) {
564 DCHECK_GE(delta_max, 0);
565 DCHECK_GE(delta_min, 0);
566 if (CapAdd(delta_min, delta_max) == 0) {
567 // This may happen if the computation of old min/max has under/overflowed
568 // resulting in no actual change in min and max.
569 return;
570 }
571 bool delta_corrupted = false;
572 for (int depth = MaxDepth(); depth >= 0; --depth) {
573 if (Min(depth, position) != kint64min &&
574 Max(depth, position) != kint64max && delta_min != kint64max &&
575 delta_max != kint64max && !delta_corrupted) { // No overflow.
576 ReduceRange(depth, position, delta_min, delta_max);
577 } else if (depth == MaxDepth()) { // Leaf.
578 SetRange(depth, position, vars_[position]->Min(),
579 vars_[position]->Max());
580 delta_corrupted = true;
581 } else { // Recompute.
582 int64 sum_min = 0;
583 int64 sum_max = 0;
584 SafeComputeNode(depth, position, &sum_min, &sum_max);
585 if (sum_min == kint64min && sum_max == kint64max) {
586 return; // Nothing to do upward.
587 }
588 SetRange(depth, position, sum_min, sum_max);
589 delta_corrupted = true;
590 }
591 position = Parent(position);
592 }
593 DCHECK_EQ(0, position);
594 target_var_->SetRange(RootMin(), RootMax());
595 }
596
597 std::string DebugString() const override {
598 return DebugStringInternal("Sum");
599 }
600
601 void Accept(ModelVisitor* const visitor) const override {
602 AcceptInternal(ModelVisitor::kSumEqual, visitor);
603 }
604
605 private:
606 bool CheckInternalState() {
607 for (int i = 0; i < vars_.size(); ++i) {
608 CheckLeaf(i, vars_[i]->Min(), vars_[i]->Max());
609 }
610 // Check up.
611 for (int i = MaxDepth() - 1; i >= 0; --i) {
612 for (int j = 0; j < Width(i); ++j) {
613 int64 sum_min = 0;
614 int64 sum_max = 0;
615 SafeComputeNode(i, j, &sum_min, &sum_max);
616 CheckNode(i, j, sum_min, sum_max);
617 }
618 }
619 return true;
620 }
621
622 void CheckLeaf(int position, int64 var_min, int64 var_max) {
623 CheckNode(MaxDepth(), position, var_min, var_max);
624 }
625
626 void CheckNode(int depth, int position, int64 node_min, int64 node_max) {
627 DCHECK_EQ(Min(depth, position), node_min);
628 DCHECK_EQ(Max(depth, position), node_max);
629 }
630
631 Demon* sum_demon_;
632};
633
634// ---------- Min Array ----------
635
636// This constraint implements min(vars) == min_var.
637class MinConstraint : public TreeArrayConstraint {
638 public:
639 MinConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
640 IntVar* const min_var)
641 : TreeArrayConstraint(solver, vars, min_var), min_demon_(nullptr) {}
642
643 ~MinConstraint() override {}
644
645 void Post() override {
646 for (int i = 0; i < vars_.size(); ++i) {
647 Demon* const demon = MakeConstraintDemon1(
648 solver(), this, &MinConstraint::LeafChanged, "LeafChanged", i);
649 vars_[i]->WhenRange(demon);
650 }
651 min_demon_ = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
652 solver(), this, &MinConstraint::MinVarChanged, "MinVarChanged"));
653 target_var_->WhenRange(min_demon_);
654 }
655
656 void InitialPropagate() override {
657 // Copy vars to leaf nodes.
658 for (int i = 0; i < vars_.size(); ++i) {
659 InitLeaf(i, vars_[i]->Min(), vars_[i]->Max());
660 }
661
662 // Compute up.
663 for (int i = MaxDepth() - 1; i >= 0; --i) {
664 for (int j = 0; j < Width(i); ++j) {
665 int64 min_min = kint64max;
666 int64 min_max = kint64max;
667 const int block_start = ChildStart(j);
668 const int block_end = ChildEnd(i, j);
669 for (int k = block_start; k <= block_end; ++k) {
670 min_min = std::min(min_min, Min(i + 1, k));
671 min_max = std::min(min_max, Max(i + 1, k));
672 }
673 InitNode(i, j, min_min, min_max);
674 }
675 }
676 // Propagate to min_var.
677 target_var_->SetRange(RootMin(), RootMax());
678
679 // Push down.
680 MinVarChanged();
681 }
682
683 void MinVarChanged() {
684 PushDown(0, 0, target_var_->Min(), target_var_->Max());
685 }
686
687 void PushDown(int depth, int position, int64 new_min, int64 new_max) {
688 // Nothing to do?
689 if (new_min <= Min(depth, position) && new_max >= Max(depth, position)) {
690 return;
691 }
692
693 // Leaf node -> push to leaf var.
694 if (IsLeaf(depth)) {
695 vars_[position]->SetRange(new_min, new_max);
696 return;
697 }
698
699 const int64 node_min = Min(depth, position);
700 const int64 node_max = Max(depth, position);
701
702 int candidate = -1;
703 int active = 0;
704 const int block_start = ChildStart(position);
705 const int block_end = ChildEnd(depth, position);
706
707 if (new_max < node_max) {
708 // Look if only one candidat to push the max down.
709 for (int i = block_start; i <= block_end; ++i) {
710 if (Min(depth + 1, i) <= new_max) {
711 if (active++ >= 1) {
712 break;
713 }
714 candidate = i;
715 }
716 }
717 if (active == 0) {
718 solver()->Fail();
719 }
720 }
721
722 if (node_min < new_min) {
723 for (int i = block_start; i <= block_end; ++i) {
724 if (i == candidate && active == 1) {
725 PushDown(depth + 1, i, new_min, new_max);
726 } else {
727 PushDown(depth + 1, i, new_min, Max(depth + 1, i));
728 }
729 }
730 } else if (active == 1) {
731 PushDown(depth + 1, candidate, Min(depth + 1, candidate), new_max);
732 }
733 }
734
735 // TODO(user): Regroup code between Min and Max constraints.
736 void LeafChanged(int term_index) {
737 IntVar* const var = vars_[term_index];
738 SetRange(MaxDepth(), term_index, var->Min(), var->Max());
739 const int parent_depth = MaxDepth() - 1;
740 const int parent = Parent(term_index);
741 const int64 old_min = var->OldMin();
742 const int64 var_min = var->Min();
743 const int64 var_max = var->Max();
744 if ((old_min == Min(parent_depth, parent) && old_min != var_min) ||
745 var_max < Max(parent_depth, parent)) {
746 // Can influence the parent bounds.
747 PushUp(term_index);
748 }
749 }
750
751 void PushUp(int position) {
752 int depth = MaxDepth();
753 while (depth > 0) {
754 const int parent = Parent(position);
755 const int parent_depth = depth - 1;
756 int64 min_min = kint64max;
757 int64 min_max = kint64max;
758 const int block_start = ChildStart(parent);
759 const int block_end = ChildEnd(parent_depth, parent);
760 for (int k = block_start; k <= block_end; ++k) {
761 min_min = std::min(min_min, Min(depth, k));
762 min_max = std::min(min_max, Max(depth, k));
763 }
764 if (min_min > Min(parent_depth, parent) ||
765 min_max < Max(parent_depth, parent)) {
766 SetRange(parent_depth, parent, min_min, min_max);
767 } else {
768 break;
769 }
770 depth = parent_depth;
771 position = parent;
772 }
773 if (depth == 0) { // We have pushed all the way up.
774 target_var_->SetRange(RootMin(), RootMax());
775 }
776 MinVarChanged();
777 }
778
779 std::string DebugString() const override {
780 return DebugStringInternal("Min");
781 }
782
783 void Accept(ModelVisitor* const visitor) const override {
784 AcceptInternal(ModelVisitor::kMinEqual, visitor);
785 }
786
787 private:
788 Demon* min_demon_;
789};
790
791class SmallMinConstraint : public Constraint {
792 public:
793 SmallMinConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
794 IntVar* const target_var)
795 : Constraint(solver),
796 vars_(vars),
797 target_var_(target_var),
798 computed_min_(0),
799 computed_max_(0) {}
800
801 ~SmallMinConstraint() override {}
802
803 void Post() override {
804 for (int i = 0; i < vars_.size(); ++i) {
805 if (!vars_[i]->Bound()) {
806 Demon* const demon = MakeConstraintDemon1(
807 solver(), this, &SmallMinConstraint::VarChanged, "VarChanged",
808 vars_[i]);
809 vars_[i]->WhenRange(demon);
810 }
811 }
812 Demon* const mdemon = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
813 solver(), this, &SmallMinConstraint::MinVarChanged, "MinVarChanged"));
814 target_var_->WhenRange(mdemon);
815 }
816
817 void InitialPropagate() override {
818 int64 min_min = kint64max;
819 int64 min_max = kint64max;
820 for (IntVar* const var : vars_) {
821 min_min = std::min(min_min, var->Min());
822 min_max = std::min(min_max, var->Max());
823 }
824 computed_min_.SetValue(solver(), min_min);
825 computed_max_.SetValue(solver(), min_max);
826 // Propagate to min_var.
827 target_var_->SetRange(min_min, min_max);
828
829 // Push down.
830 MinVarChanged();
831 }
832
833 std::string DebugString() const override {
834 return absl::StrFormat("SmallMin(%s) == %s",
836 target_var_->DebugString());
837 }
838
839 void Accept(ModelVisitor* const visitor) const override {
840 visitor->BeginVisitConstraint(ModelVisitor::kMinEqual, this);
841 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
842 vars_);
843 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
845 visitor->EndVisitConstraint(ModelVisitor::kMinEqual, this);
846 }
847
848 private:
849 void VarChanged(IntVar* var) {
850 const int64 old_min = var->OldMin();
851 const int64 var_min = var->Min();
852 const int64 var_max = var->Max();
853 if ((old_min == computed_min_.Value() && old_min != var_min) ||
854 var_max < computed_max_.Value()) {
855 // Can influence the min var bounds.
856 int64 min_min = kint64max;
857 int64 min_max = kint64max;
858 for (IntVar* const var : vars_) {
859 min_min = std::min(min_min, var->Min());
860 min_max = std::min(min_max, var->Max());
861 }
862 if (min_min > computed_min_.Value() || min_max < computed_max_.Value()) {
863 computed_min_.SetValue(solver(), min_min);
864 computed_max_.SetValue(solver(), min_max);
865 target_var_->SetRange(computed_min_.Value(), computed_max_.Value());
866 }
867 }
868 MinVarChanged();
869 }
870
871 void MinVarChanged() {
872 const int64 new_min = target_var_->Min();
873 const int64 new_max = target_var_->Max();
874 // Nothing to do?
875 if (new_min <= computed_min_.Value() && new_max >= computed_max_.Value()) {
876 return;
877 }
878
879 IntVar* candidate = nullptr;
880 int active = 0;
881
882 if (new_max < computed_max_.Value()) {
883 // Look if only one candidate to push the max down.
884 for (IntVar* const var : vars_) {
885 if (var->Min() <= new_max) {
886 if (active++ >= 1) {
887 break;
888 }
889 candidate = var;
890 }
891 }
892 if (active == 0) {
893 solver()->Fail();
894 }
895 }
896 if (computed_min_.Value() < new_min) {
897 if (active == 1) {
898 candidate->SetRange(new_min, new_max);
899 } else {
900 for (IntVar* const var : vars_) {
901 var->SetMin(new_min);
902 }
903 }
904 } else if (active == 1) {
905 candidate->SetMax(new_max);
906 }
907 }
908
909 std::vector<IntVar*> vars_;
910 IntVar* const target_var_;
911 Rev<int64> computed_min_;
912 Rev<int64> computed_max_;
913};
914
915// ---------- Max Array ----------
916
917// This constraint implements max(vars) == max_var.
918class MaxConstraint : public TreeArrayConstraint {
919 public:
920 MaxConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
921 IntVar* const max_var)
922 : TreeArrayConstraint(solver, vars, max_var), max_demon_(nullptr) {}
923
924 ~MaxConstraint() override {}
925
926 void Post() override {
927 for (int i = 0; i < vars_.size(); ++i) {
928 Demon* const demon = MakeConstraintDemon1(
929 solver(), this, &MaxConstraint::LeafChanged, "LeafChanged", i);
930 vars_[i]->WhenRange(demon);
931 }
932 max_demon_ = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
933 solver(), this, &MaxConstraint::MaxVarChanged, "MaxVarChanged"));
934 target_var_->WhenRange(max_demon_);
935 }
936
937 void InitialPropagate() override {
938 // Copy vars to leaf nodes.
939 for (int i = 0; i < vars_.size(); ++i) {
940 InitLeaf(i, vars_[i]->Min(), vars_[i]->Max());
941 }
942
943 // Compute up.
944 for (int i = MaxDepth() - 1; i >= 0; --i) {
945 for (int j = 0; j < Width(i); ++j) {
946 int64 max_min = kint64min;
947 int64 max_max = kint64min;
948 const int block_start = ChildStart(j);
949 const int block_end = ChildEnd(i, j);
950 for (int k = block_start; k <= block_end; ++k) {
951 max_min = std::max(max_min, Min(i + 1, k));
952 max_max = std::max(max_max, Max(i + 1, k));
953 }
954 InitNode(i, j, max_min, max_max);
955 }
956 }
957 // Propagate to min_var.
958 target_var_->SetRange(RootMin(), RootMax());
959
960 // Push down.
961 MaxVarChanged();
962 }
963
964 void MaxVarChanged() {
965 PushDown(0, 0, target_var_->Min(), target_var_->Max());
966 }
967
968 void PushDown(int depth, int position, int64 new_min, int64 new_max) {
969 // Nothing to do?
970 if (new_min <= Min(depth, position) && new_max >= Max(depth, position)) {
971 return;
972 }
973
974 // Leaf node -> push to leaf var.
975 if (IsLeaf(depth)) {
976 vars_[position]->SetRange(new_min, new_max);
977 return;
978 }
979
980 const int64 node_min = Min(depth, position);
981 const int64 node_max = Max(depth, position);
982
983 int candidate = -1;
984 int active = 0;
985 const int block_start = ChildStart(position);
986 const int block_end = ChildEnd(depth, position);
987
988 if (node_min < new_min) {
989 // Look if only one candidat to push the max down.
990 for (int i = block_start; i <= block_end; ++i) {
991 if (Max(depth + 1, i) >= new_min) {
992 if (active++ >= 1) {
993 break;
994 }
995 candidate = i;
996 }
997 }
998 if (active == 0) {
999 solver()->Fail();
1000 }
1001 }
1002
1003 if (node_max > new_max) {
1004 for (int i = block_start; i <= block_end; ++i) {
1005 if (i == candidate && active == 1) {
1006 PushDown(depth + 1, i, new_min, new_max);
1007 } else {
1008 PushDown(depth + 1, i, Min(depth + 1, i), new_max);
1009 }
1010 }
1011 } else if (active == 1) {
1012 PushDown(depth + 1, candidate, new_min, Max(depth + 1, candidate));
1013 }
1014 }
1015
1016 void LeafChanged(int term_index) {
1017 IntVar* const var = vars_[term_index];
1018 SetRange(MaxDepth(), term_index, var->Min(), var->Max());
1019 const int parent_depth = MaxDepth() - 1;
1020 const int parent = Parent(term_index);
1021 const int64 old_max = var->OldMax();
1022 const int64 var_min = var->Min();
1023 const int64 var_max = var->Max();
1024 if ((old_max == Max(parent_depth, parent) && old_max != var_max) ||
1025 var_min > Min(parent_depth, parent)) {
1026 // Can influence the parent bounds.
1027 PushUp(term_index);
1028 }
1029 }
1030
1031 void PushUp(int position) {
1032 int depth = MaxDepth();
1033 while (depth > 0) {
1034 const int parent = Parent(position);
1035 const int parent_depth = depth - 1;
1036 int64 max_min = kint64min;
1037 int64 max_max = kint64min;
1038 const int block_start = ChildStart(parent);
1039 const int block_end = ChildEnd(parent_depth, parent);
1040 for (int k = block_start; k <= block_end; ++k) {
1041 max_min = std::max(max_min, Min(depth, k));
1042 max_max = std::max(max_max, Max(depth, k));
1043 }
1044 if (max_min > Min(parent_depth, parent) ||
1045 max_max < Max(parent_depth, parent)) {
1046 SetRange(parent_depth, parent, max_min, max_max);
1047 } else {
1048 break;
1049 }
1050 depth = parent_depth;
1051 position = parent;
1052 }
1053 if (depth == 0) { // We have pushed all the way up.
1054 target_var_->SetRange(RootMin(), RootMax());
1055 }
1056 MaxVarChanged();
1057 }
1058
1059 std::string DebugString() const override {
1060 return DebugStringInternal("Max");
1061 }
1062
1063 void Accept(ModelVisitor* const visitor) const override {
1064 AcceptInternal(ModelVisitor::kMaxEqual, visitor);
1065 }
1066
1067 private:
1068 Demon* max_demon_;
1069};
1070
1071class SmallMaxConstraint : public Constraint {
1072 public:
1073 SmallMaxConstraint(Solver* const solver, const std::vector<IntVar*>& vars,
1074 IntVar* const target_var)
1075 : Constraint(solver),
1076 vars_(vars),
1077 target_var_(target_var),
1078 computed_min_(0),
1079 computed_max_(0) {}
1080
1081 ~SmallMaxConstraint() override {}
1082
1083 void Post() override {
1084 for (int i = 0; i < vars_.size(); ++i) {
1085 if (!vars_[i]->Bound()) {
1086 Demon* const demon = MakeConstraintDemon1(
1087 solver(), this, &SmallMaxConstraint::VarChanged, "VarChanged",
1088 vars_[i]);
1089 vars_[i]->WhenRange(demon);
1090 }
1091 }
1092 Demon* const mdemon = solver()->RegisterDemon(MakeDelayedConstraintDemon0(
1093 solver(), this, &SmallMaxConstraint::MaxVarChanged, "MinVarChanged"));
1094 target_var_->WhenRange(mdemon);
1095 }
1096
1097 void InitialPropagate() override {
1098 int64 max_min = kint64min;
1099 int64 max_max = kint64min;
1100 for (IntVar* const var : vars_) {
1101 max_min = std::max(max_min, var->Min());
1102 max_max = std::max(max_max, var->Max());
1103 }
1104 computed_min_.SetValue(solver(), max_min);
1105 computed_max_.SetValue(solver(), max_max);
1106 // Propagate to min_var.
1107 target_var_->SetRange(max_min, max_max);
1108
1109 // Push down.
1110 MaxVarChanged();
1111 }
1112
1113 std::string DebugString() const override {
1114 return absl::StrFormat("SmallMax(%s) == %s",
1116 target_var_->DebugString());
1117 }
1118
1119 void Accept(ModelVisitor* const visitor) const override {
1120 visitor->BeginVisitConstraint(ModelVisitor::kMaxEqual, this);
1121 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1122 vars_);
1123 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1124 target_var_);
1125 visitor->EndVisitConstraint(ModelVisitor::kMaxEqual, this);
1126 }
1127
1128 private:
1129 void VarChanged(IntVar* var) {
1130 const int64 old_max = var->OldMax();
1131 const int64 var_min = var->Min();
1132 const int64 var_max = var->Max();
1133 if ((old_max == computed_max_.Value() && old_max != var_max) ||
1134 var_min > computed_min_.Value()) { // REWRITE
1135 // Can influence the min var bounds.
1136 int64 max_min = kint64min;
1137 int64 max_max = kint64min;
1138 for (IntVar* const var : vars_) {
1139 max_min = std::max(max_min, var->Min());
1140 max_max = std::max(max_max, var->Max());
1141 }
1142 if (max_min > computed_min_.Value() || max_max < computed_max_.Value()) {
1143 computed_min_.SetValue(solver(), max_min);
1144 computed_max_.SetValue(solver(), max_max);
1145 target_var_->SetRange(computed_min_.Value(), computed_max_.Value());
1146 }
1147 }
1148 MaxVarChanged();
1149 }
1150
1151 void MaxVarChanged() {
1152 const int64 new_min = target_var_->Min();
1153 const int64 new_max = target_var_->Max();
1154 // Nothing to do?
1155 if (new_min <= computed_min_.Value() && new_max >= computed_max_.Value()) {
1156 return;
1157 }
1158
1159 IntVar* candidate = nullptr;
1160 int active = 0;
1161
1162 if (new_min > computed_min_.Value()) {
1163 // Look if only one candidate to push the max down.
1164 for (IntVar* const var : vars_) {
1165 if (var->Max() >= new_min) {
1166 if (active++ >= 1) {
1167 break;
1168 }
1169 candidate = var;
1170 }
1171 }
1172 if (active == 0) {
1173 solver()->Fail();
1174 }
1175 }
1176 if (computed_max_.Value() > new_max) {
1177 if (active == 1) {
1178 candidate->SetRange(new_min, new_max);
1179 } else {
1180 for (IntVar* const var : vars_) {
1181 var->SetMax(new_max);
1182 }
1183 }
1184 } else if (active == 1) {
1185 candidate->SetMin(new_min);
1186 }
1187 }
1188
1189 std::vector<IntVar*> vars_;
1190 IntVar* const target_var_;
1191 Rev<int64> computed_min_;
1192 Rev<int64> computed_max_;
1193};
1194
1195// Boolean And and Ors
1196
1197class ArrayBoolAndEq : public CastConstraint {
1198 public:
1199 ArrayBoolAndEq(Solver* const s, const std::vector<IntVar*>& vars,
1200 IntVar* const target)
1201 : CastConstraint(s, target),
1202 vars_(vars),
1203 demons_(vars.size()),
1204 unbounded_(0) {}
1205
1206 ~ArrayBoolAndEq() override {}
1207
1208 void Post() override {
1209 for (int i = 0; i < vars_.size(); ++i) {
1210 if (!vars_[i]->Bound()) {
1211 demons_[i] =
1212 MakeConstraintDemon1(solver(), this, &ArrayBoolAndEq::PropagateVar,
1213 "PropagateVar", vars_[i]);
1214 vars_[i]->WhenBound(demons_[i]);
1215 }
1216 }
1217 if (!target_var_->Bound()) {
1218 Demon* const target_demon = MakeConstraintDemon0(
1219 solver(), this, &ArrayBoolAndEq::PropagateTarget, "PropagateTarget");
1220 target_var_->WhenBound(target_demon);
1221 }
1222 }
1223
1224 void InitialPropagate() override {
1225 target_var_->SetRange(0, 1);
1226 if (target_var_->Min() == 1) {
1227 for (int i = 0; i < vars_.size(); ++i) {
1228 vars_[i]->SetMin(1);
1229 }
1230 } else {
1231 int possible_zero = -1;
1232 int ones = 0;
1233 int unbounded = 0;
1234 for (int i = 0; i < vars_.size(); ++i) {
1235 if (!vars_[i]->Bound()) {
1236 unbounded++;
1237 possible_zero = i;
1238 } else if (vars_[i]->Max() == 0) {
1239 InhibitAll();
1240 target_var_->SetMax(0);
1241 return;
1242 } else {
1243 DCHECK_EQ(1, vars_[i]->Min());
1244 ones++;
1245 }
1246 }
1247 if (unbounded == 0) {
1248 target_var_->SetMin(1);
1249 } else if (target_var_->Max() == 0 && unbounded == 1) {
1250 CHECK_NE(-1, possible_zero);
1251 vars_[possible_zero]->SetMax(0);
1252 } else {
1253 unbounded_.SetValue(solver(), unbounded);
1254 }
1255 }
1256 }
1257
1258 void PropagateVar(IntVar* var) {
1259 if (var->Min() == 1) {
1260 unbounded_.Decr(solver());
1261 if (unbounded_.Value() == 0 && !decided_.Switched()) {
1262 target_var_->SetMin(1);
1263 decided_.Switch(solver());
1264 } else if (target_var_->Max() == 0 && unbounded_.Value() == 1 &&
1265 !decided_.Switched()) {
1266 ForceToZero();
1267 }
1268 } else {
1269 InhibitAll();
1270 target_var_->SetMax(0);
1271 }
1272 }
1273
1274 void PropagateTarget() {
1275 if (target_var_->Min() == 1) {
1276 for (int i = 0; i < vars_.size(); ++i) {
1277 vars_[i]->SetMin(1);
1278 }
1279 } else {
1280 if (unbounded_.Value() == 1 && !decided_.Switched()) {
1281 ForceToZero();
1282 }
1283 }
1284 }
1285
1286 std::string DebugString() const override {
1287 return absl::StrFormat("And(%s) == %s", JoinDebugStringPtr(vars_, ", "),
1288 target_var_->DebugString());
1289 }
1290
1291 void Accept(ModelVisitor* const visitor) const override {
1292 visitor->BeginVisitConstraint(ModelVisitor::kMinEqual, this);
1293 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1294 vars_);
1295 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1296 target_var_);
1297 visitor->EndVisitConstraint(ModelVisitor::kMinEqual, this);
1298 }
1299
1300 private:
1301 void InhibitAll() {
1302 for (int i = 0; i < demons_.size(); ++i) {
1303 if (demons_[i] != nullptr) {
1304 demons_[i]->inhibit(solver());
1305 }
1306 }
1307 }
1308
1309 void ForceToZero() {
1310 for (int i = 0; i < vars_.size(); ++i) {
1311 if (vars_[i]->Min() == 0) {
1312 vars_[i]->SetValue(0);
1313 decided_.Switch(solver());
1314 return;
1315 }
1316 }
1317 solver()->Fail();
1318 }
1319
1320 const std::vector<IntVar*> vars_;
1321 std::vector<Demon*> demons_;
1322 NumericalRev<int> unbounded_;
1323 RevSwitch decided_;
1324};
1325
1326class ArrayBoolOrEq : public CastConstraint {
1327 public:
1328 ArrayBoolOrEq(Solver* const s, const std::vector<IntVar*>& vars,
1329 IntVar* const target)
1330 : CastConstraint(s, target),
1331 vars_(vars),
1332 demons_(vars.size()),
1333 unbounded_(0) {}
1334
1335 ~ArrayBoolOrEq() override {}
1336
1337 void Post() override {
1338 for (int i = 0; i < vars_.size(); ++i) {
1339 if (!vars_[i]->Bound()) {
1340 demons_[i] =
1341 MakeConstraintDemon1(solver(), this, &ArrayBoolOrEq::PropagateVar,
1342 "PropagateVar", vars_[i]);
1343 vars_[i]->WhenBound(demons_[i]);
1344 }
1345 }
1346 if (!target_var_->Bound()) {
1347 Demon* const target_demon = MakeConstraintDemon0(
1348 solver(), this, &ArrayBoolOrEq::PropagateTarget, "PropagateTarget");
1349 target_var_->WhenBound(target_demon);
1350 }
1351 }
1352
1353 void InitialPropagate() override {
1354 target_var_->SetRange(0, 1);
1355 if (target_var_->Max() == 0) {
1356 for (int i = 0; i < vars_.size(); ++i) {
1357 vars_[i]->SetMax(0);
1358 }
1359 } else {
1360 int zeros = 0;
1361 int possible_one = -1;
1362 int unbounded = 0;
1363 for (int i = 0; i < vars_.size(); ++i) {
1364 if (!vars_[i]->Bound()) {
1365 unbounded++;
1366 possible_one = i;
1367 } else if (vars_[i]->Min() == 1) {
1368 InhibitAll();
1369 target_var_->SetMin(1);
1370 return;
1371 } else {
1372 DCHECK_EQ(0, vars_[i]->Max());
1373 zeros++;
1374 }
1375 }
1376 if (unbounded == 0) {
1377 target_var_->SetMax(0);
1378 } else if (target_var_->Min() == 1 && unbounded == 1) {
1379 CHECK_NE(-1, possible_one);
1380 vars_[possible_one]->SetMin(1);
1381 } else {
1382 unbounded_.SetValue(solver(), unbounded);
1383 }
1384 }
1385 }
1386
1387 void PropagateVar(IntVar* var) {
1388 if (var->Min() == 0) {
1389 unbounded_.Decr(solver());
1390 if (unbounded_.Value() == 0 && !decided_.Switched()) {
1391 target_var_->SetMax(0);
1392 decided_.Switch(solver());
1393 }
1394 if (target_var_->Min() == 1 && unbounded_.Value() == 1 &&
1395 !decided_.Switched()) {
1396 ForceToOne();
1397 }
1398 } else {
1399 InhibitAll();
1400 target_var_->SetMin(1);
1401 }
1402 }
1403
1404 void PropagateTarget() {
1405 if (target_var_->Max() == 0) {
1406 for (int i = 0; i < vars_.size(); ++i) {
1407 vars_[i]->SetMax(0);
1408 }
1409 } else {
1410 if (unbounded_.Value() == 1 && !decided_.Switched()) {
1411 ForceToOne();
1412 }
1413 }
1414 }
1415
1416 std::string DebugString() const override {
1417 return absl::StrFormat("Or(%s) == %s", JoinDebugStringPtr(vars_, ", "),
1418 target_var_->DebugString());
1419 }
1420
1421 void Accept(ModelVisitor* const visitor) const override {
1422 visitor->BeginVisitConstraint(ModelVisitor::kMaxEqual, this);
1423 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1424 vars_);
1425 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1426 target_var_);
1427 visitor->EndVisitConstraint(ModelVisitor::kMaxEqual, this);
1428 }
1429
1430 private:
1431 void InhibitAll() {
1432 for (int i = 0; i < demons_.size(); ++i) {
1433 if (demons_[i] != nullptr) {
1434 demons_[i]->inhibit(solver());
1435 }
1436 }
1437 }
1438
1439 void ForceToOne() {
1440 for (int i = 0; i < vars_.size(); ++i) {
1441 if (vars_[i]->Max() == 1) {
1442 vars_[i]->SetValue(1);
1443 decided_.Switch(solver());
1444 return;
1445 }
1446 }
1447 solver()->Fail();
1448 }
1449
1450 const std::vector<IntVar*> vars_;
1451 std::vector<Demon*> demons_;
1452 NumericalRev<int> unbounded_;
1453 RevSwitch decided_;
1454};
1455
1456// ---------- Specialized cases ----------
1457
1458class BaseSumBooleanConstraint : public Constraint {
1459 public:
1460 BaseSumBooleanConstraint(Solver* const s, const std::vector<IntVar*>& vars)
1461 : Constraint(s), vars_(vars) {}
1462
1463 ~BaseSumBooleanConstraint() override {}
1464
1465 protected:
1466 std::string DebugStringInternal(const std::string& name) const {
1467 return absl::StrFormat("%s(%s)", name, JoinDebugStringPtr(vars_, ", "));
1468 }
1469
1470 const std::vector<IntVar*> vars_;
1471 RevSwitch inactive_;
1472};
1473
1474// ----- Sum of Boolean <= 1 -----
1475
1476class SumBooleanLessOrEqualToOne : public BaseSumBooleanConstraint {
1477 public:
1478 SumBooleanLessOrEqualToOne(Solver* const s, const std::vector<IntVar*>& vars)
1479 : BaseSumBooleanConstraint(s, vars) {}
1480
1481 ~SumBooleanLessOrEqualToOne() override {}
1482
1483 void Post() override {
1484 for (int i = 0; i < vars_.size(); ++i) {
1485 if (!vars_[i]->Bound()) {
1486 Demon* u = MakeConstraintDemon1(solver(), this,
1487 &SumBooleanLessOrEqualToOne::Update,
1488 "Update", vars_[i]);
1489 vars_[i]->WhenBound(u);
1490 }
1491 }
1492 }
1493
1494 void InitialPropagate() override {
1495 for (int i = 0; i < vars_.size(); ++i) {
1496 if (vars_[i]->Min() == 1) {
1497 PushAllToZeroExcept(vars_[i]);
1498 return;
1499 }
1500 }
1501 }
1502
1503 void Update(IntVar* var) {
1504 if (!inactive_.Switched()) {
1505 DCHECK(var->Bound());
1506 if (var->Min() == 1) {
1507 PushAllToZeroExcept(var);
1508 }
1509 }
1510 }
1511
1512 void PushAllToZeroExcept(IntVar* var) {
1513 inactive_.Switch(solver());
1514 for (int i = 0; i < vars_.size(); ++i) {
1515 IntVar* const other = vars_[i];
1516 if (other != var && other->Max() != 0) {
1517 other->SetMax(0);
1518 }
1519 }
1520 }
1521
1522 std::string DebugString() const override {
1523 return DebugStringInternal("SumBooleanLessOrEqualToOne");
1524 }
1525
1526 void Accept(ModelVisitor* const visitor) const override {
1527 visitor->BeginVisitConstraint(ModelVisitor::kSumLessOrEqual, this);
1528 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1529 vars_);
1530 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, 1);
1531 visitor->EndVisitConstraint(ModelVisitor::kSumLessOrEqual, this);
1532 }
1533};
1534
1535// ----- Sum of Boolean >= 1 -----
1536
1537// We implement this one as a Max(array) == 1.
1538
1539class SumBooleanGreaterOrEqualToOne : public BaseSumBooleanConstraint {
1540 public:
1541 SumBooleanGreaterOrEqualToOne(Solver* const s,
1542 const std::vector<IntVar*>& vars);
1543 ~SumBooleanGreaterOrEqualToOne() override {}
1544
1545 void Post() override;
1546 void InitialPropagate() override;
1547
1548 void Update(int index);
1549 void UpdateVar();
1550
1551 std::string DebugString() const override;
1552
1553 void Accept(ModelVisitor* const visitor) const override {
1554 visitor->BeginVisitConstraint(ModelVisitor::kSumGreaterOrEqual, this);
1555 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1556 vars_);
1557 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, 1);
1558 visitor->EndVisitConstraint(ModelVisitor::kSumGreaterOrEqual, this);
1559 }
1560
1561 private:
1562 RevBitSet bits_;
1563};
1564
1565SumBooleanGreaterOrEqualToOne::SumBooleanGreaterOrEqualToOne(
1566 Solver* const s, const std::vector<IntVar*>& vars)
1567 : BaseSumBooleanConstraint(s, vars), bits_(vars.size()) {}
1568
1569void SumBooleanGreaterOrEqualToOne::Post() {
1570 for (int i = 0; i < vars_.size(); ++i) {
1571 Demon* d = MakeConstraintDemon1(
1572 solver(), this, &SumBooleanGreaterOrEqualToOne::Update, "Update", i);
1573 vars_[i]->WhenRange(d);
1574 }
1575}
1576
1577void SumBooleanGreaterOrEqualToOne::InitialPropagate() {
1578 for (int i = 0; i < vars_.size(); ++i) {
1579 IntVar* const var = vars_[i];
1580 if (var->Min() == 1LL) {
1581 inactive_.Switch(solver());
1582 return;
1583 }
1584 if (var->Max() == 1LL) {
1585 bits_.SetToOne(solver(), i);
1586 }
1587 }
1588 if (bits_.IsCardinalityZero()) {
1589 solver()->Fail();
1590 } else if (bits_.IsCardinalityOne()) {
1591 vars_[bits_.GetFirstBit(0)]->SetValue(int64{1});
1592 inactive_.Switch(solver());
1593 }
1594}
1595
1596void SumBooleanGreaterOrEqualToOne::Update(int index) {
1597 if (!inactive_.Switched()) {
1598 if (vars_[index]->Min() == 1LL) { // Bound to 1.
1599 inactive_.Switch(solver());
1600 } else {
1601 bits_.SetToZero(solver(), index);
1602 if (bits_.IsCardinalityZero()) {
1603 solver()->Fail();
1604 } else if (bits_.IsCardinalityOne()) {
1605 vars_[bits_.GetFirstBit(0)]->SetValue(int64{1});
1606 inactive_.Switch(solver());
1607 }
1608 }
1609 }
1610}
1611
1612std::string SumBooleanGreaterOrEqualToOne::DebugString() const {
1613 return DebugStringInternal("SumBooleanGreaterOrEqualToOne");
1614}
1615
1616// ----- Sum of Boolean == 1 -----
1617
1618class SumBooleanEqualToOne : public BaseSumBooleanConstraint {
1619 public:
1620 SumBooleanEqualToOne(Solver* const s, const std::vector<IntVar*>& vars)
1621 : BaseSumBooleanConstraint(s, vars), active_vars_(0) {}
1622
1623 ~SumBooleanEqualToOne() override {}
1624
1625 void Post() override {
1626 for (int i = 0; i < vars_.size(); ++i) {
1627 Demon* u = MakeConstraintDemon1(
1628 solver(), this, &SumBooleanEqualToOne::Update, "Update", i);
1629 vars_[i]->WhenBound(u);
1630 }
1631 }
1632
1633 void InitialPropagate() override {
1634 int min1 = 0;
1635 int max1 = 0;
1636 int index_min = -1;
1637 int index_max = -1;
1638 for (int i = 0; i < vars_.size(); ++i) {
1639 const IntVar* const var = vars_[i];
1640 if (var->Min() == 1) {
1641 min1++;
1642 index_min = i;
1643 }
1644 if (var->Max() == 1) {
1645 max1++;
1646 index_max = i;
1647 }
1648 }
1649 if (min1 > 1 || max1 == 0) {
1650 solver()->Fail();
1651 } else if (min1 == 1) {
1652 DCHECK_NE(-1, index_min);
1653 PushAllToZeroExcept(index_min);
1654 } else if (max1 == 1) {
1655 DCHECK_NE(-1, index_max);
1656 vars_[index_max]->SetValue(1);
1657 inactive_.Switch(solver());
1658 } else {
1659 active_vars_.SetValue(solver(), max1);
1660 }
1661 }
1662
1663 void Update(int index) {
1664 if (!inactive_.Switched()) {
1665 DCHECK(vars_[index]->Bound());
1666 const int64 value = vars_[index]->Min(); // Faster than Value().
1667 if (value == 0) {
1668 active_vars_.Decr(solver());
1669 DCHECK_GE(active_vars_.Value(), 0);
1670 if (active_vars_.Value() == 0) {
1671 solver()->Fail();
1672 } else if (active_vars_.Value() == 1) {
1673 bool found = false;
1674 for (int i = 0; i < vars_.size(); ++i) {
1675 IntVar* const var = vars_[i];
1676 if (var->Max() == 1) {
1677 var->SetValue(1);
1678 PushAllToZeroExcept(i);
1679 found = true;
1680 break;
1681 }
1682 }
1683 if (!found) {
1684 solver()->Fail();
1685 }
1686 }
1687 } else {
1688 PushAllToZeroExcept(index);
1689 }
1690 }
1691 }
1692
1693 void PushAllToZeroExcept(int index) {
1694 inactive_.Switch(solver());
1695 for (int i = 0; i < vars_.size(); ++i) {
1696 if (i != index && vars_[i]->Max() != 0) {
1697 vars_[i]->SetMax(0);
1698 }
1699 }
1700 }
1701
1702 std::string DebugString() const override {
1703 return DebugStringInternal("SumBooleanEqualToOne");
1704 }
1705
1706 void Accept(ModelVisitor* const visitor) const override {
1707 visitor->BeginVisitConstraint(ModelVisitor::kSumEqual, this);
1708 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1709 vars_);
1710 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, 1);
1711 visitor->EndVisitConstraint(ModelVisitor::kSumEqual, this);
1712 }
1713
1714 private:
1715 NumericalRev<int> active_vars_;
1716};
1717
1718// ----- Sum of Boolean Equal To Var -----
1719
1720class SumBooleanEqualToVar : public BaseSumBooleanConstraint {
1721 public:
1722 SumBooleanEqualToVar(Solver* const s, const std::vector<IntVar*>& bool_vars,
1723 IntVar* const sum_var)
1724 : BaseSumBooleanConstraint(s, bool_vars),
1725 num_possible_true_vars_(0),
1726 num_always_true_vars_(0),
1727 sum_var_(sum_var) {}
1728
1729 ~SumBooleanEqualToVar() override {}
1730
1731 void Post() override {
1732 for (int i = 0; i < vars_.size(); ++i) {
1733 Demon* const u = MakeConstraintDemon1(
1734 solver(), this, &SumBooleanEqualToVar::Update, "Update", i);
1735 vars_[i]->WhenBound(u);
1736 }
1737 if (!sum_var_->Bound()) {
1738 Demon* const u = MakeConstraintDemon0(
1739 solver(), this, &SumBooleanEqualToVar::UpdateVar, "UpdateVar");
1740 sum_var_->WhenRange(u);
1741 }
1742 }
1743
1744 void InitialPropagate() override {
1745 int num_always_true_vars = 0;
1746 int possible_true = 0;
1747 for (int i = 0; i < vars_.size(); ++i) {
1748 const IntVar* const var = vars_[i];
1749 if (var->Min() == 1) {
1750 num_always_true_vars++;
1751 }
1752 if (var->Max() == 1) {
1753 possible_true++;
1754 }
1755 }
1756 sum_var_->SetRange(num_always_true_vars, possible_true);
1757 const int64 var_min = sum_var_->Min();
1758 const int64 var_max = sum_var_->Max();
1759 if (num_always_true_vars == var_max && possible_true > var_max) {
1760 PushAllUnboundToZero();
1761 } else if (possible_true == var_min && num_always_true_vars < var_min) {
1762 PushAllUnboundToOne();
1763 } else {
1764 num_possible_true_vars_.SetValue(solver(), possible_true);
1765 num_always_true_vars_.SetValue(solver(), num_always_true_vars);
1766 }
1767 }
1768
1769 void UpdateVar() {
1770 if (!inactive_.Switched()) {
1771 if (num_possible_true_vars_.Value() == sum_var_->Min()) {
1772 PushAllUnboundToOne();
1773 sum_var_->SetValue(num_possible_true_vars_.Value());
1774 } else if (num_always_true_vars_.Value() == sum_var_->Max()) {
1775 PushAllUnboundToZero();
1776 sum_var_->SetValue(num_always_true_vars_.Value());
1777 }
1778 }
1779 }
1780
1781 void Update(int index) {
1782 if (!inactive_.Switched()) {
1783 DCHECK(vars_[index]->Bound());
1784 const int64 value = vars_[index]->Min(); // Faster than Value().
1785 if (value == 0) {
1786 num_possible_true_vars_.Decr(solver());
1787 sum_var_->SetRange(num_always_true_vars_.Value(),
1788 num_possible_true_vars_.Value());
1789 if (num_possible_true_vars_.Value() == sum_var_->Min()) {
1790 PushAllUnboundToOne();
1791 }
1792 } else {
1793 DCHECK_EQ(1, value);
1794 num_always_true_vars_.Incr(solver());
1795 sum_var_->SetRange(num_always_true_vars_.Value(),
1796 num_possible_true_vars_.Value());
1797 if (num_always_true_vars_.Value() == sum_var_->Max()) {
1798 PushAllUnboundToZero();
1799 }
1800 }
1801 }
1802 }
1803
1804 void PushAllUnboundToZero() {
1805 int64 counter = 0;
1806 inactive_.Switch(solver());
1807 for (int i = 0; i < vars_.size(); ++i) {
1808 if (vars_[i]->Min() == 0) {
1809 vars_[i]->SetValue(0);
1810 } else {
1811 counter++;
1812 }
1813 }
1814 if (counter < sum_var_->Min() || counter > sum_var_->Max()) {
1815 solver()->Fail();
1816 }
1817 }
1818
1819 void PushAllUnboundToOne() {
1820 int64 counter = 0;
1821 inactive_.Switch(solver());
1822 for (int i = 0; i < vars_.size(); ++i) {
1823 if (vars_[i]->Max() == 1) {
1824 vars_[i]->SetValue(1);
1825 counter++;
1826 }
1827 }
1828 if (counter < sum_var_->Min() || counter > sum_var_->Max()) {
1829 solver()->Fail();
1830 }
1831 }
1832
1833 std::string DebugString() const override {
1834 return absl::StrFormat("%s == %s", DebugStringInternal("SumBoolean"),
1835 sum_var_->DebugString());
1836 }
1837
1838 void Accept(ModelVisitor* const visitor) const override {
1839 visitor->BeginVisitConstraint(ModelVisitor::kSumEqual, this);
1840 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1841 vars_);
1842 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1843 sum_var_);
1844 visitor->EndVisitConstraint(ModelVisitor::kSumEqual, this);
1845 }
1846
1847 private:
1848 NumericalRev<int> num_possible_true_vars_;
1849 NumericalRev<int> num_always_true_vars_;
1850 IntVar* const sum_var_;
1851};
1852
1853// ---------- ScalProd ----------
1854
1855// ----- Boolean Scal Prod -----
1856
1857struct Container {
1858 IntVar* var;
1860 Container(IntVar* v, int64 c) : var(v), coef(c) {}
1861 bool operator<(const Container& c) const { return (coef < c.coef); }
1862};
1863
1864// This method will sort both vars and coefficients in increasing
1865// coefficient order. Vars with null coefficients will be
1866// removed. Bound vars will be collected and the sum of the
1867// corresponding products (when the var is bound to 1) is returned by
1868// this method.
1869// If keep_inside is true, the constant will be added back into the
1870// scalprod as IntConst(1) * constant.
1871int64 SortBothChangeConstant(std::vector<IntVar*>* const vars,
1872 std::vector<int64>* const coefs,
1873 bool keep_inside) {
1874 CHECK(vars != nullptr);
1875 CHECK(coefs != nullptr);
1876 if (vars->empty()) {
1877 return 0;
1878 }
1879 int64 cst = 0;
1880 std::vector<Container> to_sort;
1881 for (int index = 0; index < vars->size(); ++index) {
1882 if ((*vars)[index]->Bound()) {
1883 cst = CapAdd(cst, CapProd((*coefs)[index], (*vars)[index]->Min()));
1884 } else if ((*coefs)[index] != 0) {
1885 to_sort.push_back(Container((*vars)[index], (*coefs)[index]));
1886 }
1887 }
1888 if (keep_inside && cst != 0) {
1889 CHECK_LT(to_sort.size(), vars->size());
1890 Solver* const solver = (*vars)[0]->solver();
1891 to_sort.push_back(Container(solver->MakeIntConst(1), cst));
1892 cst = 0;
1893 }
1894 std::sort(to_sort.begin(), to_sort.end());
1895 for (int index = 0; index < to_sort.size(); ++index) {
1896 (*vars)[index] = to_sort[index].var;
1897 (*coefs)[index] = to_sort[index].coef;
1898 }
1899 vars->resize(to_sort.size());
1900 coefs->resize(to_sort.size());
1901 return cst;
1902}
1903
1904// This constraint implements sum(vars) == var. It is delayed such
1905// that propagation only occurs when all variables have been touched.
1906class BooleanScalProdLessConstant : public Constraint {
1907 public:
1908 BooleanScalProdLessConstant(Solver* const s, const std::vector<IntVar*>& vars,
1909 const std::vector<int64>& coefs,
1910 int64 upper_bound)
1911 : Constraint(s),
1912 vars_(vars),
1913 coefs_(coefs),
1914 upper_bound_(upper_bound),
1915 first_unbound_backward_(vars.size() - 1),
1916 sum_of_bound_variables_(0LL),
1917 max_coefficient_(0) {
1918 CHECK(!vars.empty());
1919 for (int i = 0; i < vars_.size(); ++i) {
1920 DCHECK_GE(coefs_[i], 0);
1921 }
1922 upper_bound_ =
1923 CapSub(upper_bound, SortBothChangeConstant(&vars_, &coefs_, false));
1924 max_coefficient_.SetValue(s, coefs_[vars_.size() - 1]);
1925 }
1926
1927 ~BooleanScalProdLessConstant() override {}
1928
1929 void Post() override {
1930 for (int var_index = 0; var_index < vars_.size(); ++var_index) {
1931 if (vars_[var_index]->Bound()) {
1932 continue;
1933 }
1934 Demon* d = MakeConstraintDemon1(solver(), this,
1935 &BooleanScalProdLessConstant::Update,
1936 "InitialPropagate", var_index);
1937 vars_[var_index]->WhenRange(d);
1938 }
1939 }
1940
1941 void PushFromTop() {
1942 const int64 slack = CapSub(upper_bound_, sum_of_bound_variables_.Value());
1943 if (slack < 0) {
1944 solver()->Fail();
1945 }
1946 if (slack < max_coefficient_.Value()) {
1947 int64 last_unbound = first_unbound_backward_.Value();
1948 for (; last_unbound >= 0; --last_unbound) {
1949 if (!vars_[last_unbound]->Bound()) {
1950 if (coefs_[last_unbound] <= slack) {
1951 max_coefficient_.SetValue(solver(), coefs_[last_unbound]);
1952 break;
1953 } else {
1954 vars_[last_unbound]->SetValue(0);
1955 }
1956 }
1957 }
1958 first_unbound_backward_.SetValue(solver(), last_unbound);
1959 }
1960 }
1961
1962 void InitialPropagate() override {
1963 Solver* const s = solver();
1964 int last_unbound = -1;
1965 int64 sum = 0LL;
1966 for (int index = 0; index < vars_.size(); ++index) {
1967 if (vars_[index]->Bound()) {
1968 const int64 value = vars_[index]->Min();
1969 sum = CapAdd(sum, CapProd(value, coefs_[index]));
1970 } else {
1971 last_unbound = index;
1972 }
1973 }
1974 sum_of_bound_variables_.SetValue(s, sum);
1975 first_unbound_backward_.SetValue(s, last_unbound);
1976 PushFromTop();
1977 }
1978
1979 void Update(int var_index) {
1980 if (vars_[var_index]->Min() == 1) {
1981 sum_of_bound_variables_.SetValue(
1982 solver(), CapAdd(sum_of_bound_variables_.Value(), coefs_[var_index]));
1983 PushFromTop();
1984 }
1985 }
1986
1987 std::string DebugString() const override {
1988 return absl::StrFormat("BooleanScalProd([%s], [%s]) <= %d)",
1990 absl::StrJoin(coefs_, ", "), upper_bound_);
1991 }
1992
1993 void Accept(ModelVisitor* const visitor) const override {
1994 visitor->BeginVisitConstraint(ModelVisitor::kScalProdLessOrEqual, this);
1995 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1996 vars_);
1997 visitor->VisitIntegerArrayArgument(ModelVisitor::kCoefficientsArgument,
1998 coefs_);
1999 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, upper_bound_);
2000 visitor->EndVisitConstraint(ModelVisitor::kScalProdLessOrEqual, this);
2001 }
2002
2003 private:
2004 std::vector<IntVar*> vars_;
2005 std::vector<int64> coefs_;
2006 int64 upper_bound_;
2007 Rev<int> first_unbound_backward_;
2008 Rev<int64> sum_of_bound_variables_;
2009 Rev<int64> max_coefficient_;
2010};
2011
2012// ----- PositiveBooleanScalProdEqVar -----
2013
2014class PositiveBooleanScalProdEqVar : public CastConstraint {
2015 public:
2016 PositiveBooleanScalProdEqVar(Solver* const s,
2017 const std::vector<IntVar*>& vars,
2018 const std::vector<int64>& coefs,
2019 IntVar* const var)
2020 : CastConstraint(s, var),
2021 vars_(vars),
2022 coefs_(coefs),
2023 first_unbound_backward_(vars.size() - 1),
2024 sum_of_bound_variables_(0LL),
2025 sum_of_all_variables_(0LL),
2026 max_coefficient_(0) {
2027 SortBothChangeConstant(&vars_, &coefs_, true);
2028 max_coefficient_.SetValue(s, coefs_[vars_.size() - 1]);
2029 }
2030
2031 ~PositiveBooleanScalProdEqVar() override {}
2032
2033 void Post() override {
2034 for (int var_index = 0; var_index < vars_.size(); ++var_index) {
2035 if (vars_[var_index]->Bound()) {
2036 continue;
2037 }
2038 Demon* const d = MakeConstraintDemon1(
2039 solver(), this, &PositiveBooleanScalProdEqVar::Update, "Update",
2040 var_index);
2041 vars_[var_index]->WhenRange(d);
2042 }
2043 if (!target_var_->Bound()) {
2044 Demon* const uv = MakeConstraintDemon0(
2045 solver(), this, &PositiveBooleanScalProdEqVar::Propagate,
2046 "Propagate");
2047 target_var_->WhenRange(uv);
2048 }
2049 }
2050
2051 void Propagate() {
2052 target_var_->SetRange(sum_of_bound_variables_.Value(),
2053 sum_of_all_variables_.Value());
2054 const int64 slack_up =
2055 CapSub(target_var_->Max(), sum_of_bound_variables_.Value());
2056 const int64 slack_down =
2057 CapSub(sum_of_all_variables_.Value(), target_var_->Min());
2058 const int64 max_coeff = max_coefficient_.Value();
2059 if (slack_down < max_coeff || slack_up < max_coeff) {
2060 int64 last_unbound = first_unbound_backward_.Value();
2061 for (; last_unbound >= 0; --last_unbound) {
2062 if (!vars_[last_unbound]->Bound()) {
2063 if (coefs_[last_unbound] > slack_up) {
2064 vars_[last_unbound]->SetValue(0);
2065 } else if (coefs_[last_unbound] > slack_down) {
2066 vars_[last_unbound]->SetValue(1);
2067 } else {
2068 max_coefficient_.SetValue(solver(), coefs_[last_unbound]);
2069 break;
2070 }
2071 }
2072 }
2073 first_unbound_backward_.SetValue(solver(), last_unbound);
2074 }
2075 }
2076
2077 void InitialPropagate() override {
2078 Solver* const s = solver();
2079 int last_unbound = -1;
2080 int64 sum_bound = 0;
2081 int64 sum_all = 0;
2082 for (int index = 0; index < vars_.size(); ++index) {
2083 const int64 value = CapProd(vars_[index]->Max(), coefs_[index]);
2084 sum_all = CapAdd(sum_all, value);
2085 if (vars_[index]->Bound()) {
2086 sum_bound = CapAdd(sum_bound, value);
2087 } else {
2088 last_unbound = index;
2089 }
2090 }
2091 sum_of_bound_variables_.SetValue(s, sum_bound);
2092 sum_of_all_variables_.SetValue(s, sum_all);
2093 first_unbound_backward_.SetValue(s, last_unbound);
2094 Propagate();
2095 }
2096
2097 void Update(int var_index) {
2098 if (vars_[var_index]->Min() == 1) {
2099 sum_of_bound_variables_.SetValue(
2100 solver(), CapAdd(sum_of_bound_variables_.Value(), coefs_[var_index]));
2101 } else {
2102 sum_of_all_variables_.SetValue(
2103 solver(), CapSub(sum_of_all_variables_.Value(), coefs_[var_index]));
2104 }
2105 Propagate();
2106 }
2107
2108 std::string DebugString() const override {
2109 return absl::StrFormat("PositiveBooleanScal([%s], [%s]) == %s",
2111 absl::StrJoin(coefs_, ", "),
2112 target_var_->DebugString());
2113 }
2114
2115 void Accept(ModelVisitor* const visitor) const override {
2116 visitor->BeginVisitConstraint(ModelVisitor::kScalProdEqual, this);
2117 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
2118 vars_);
2119 visitor->VisitIntegerArrayArgument(ModelVisitor::kCoefficientsArgument,
2120 coefs_);
2121 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
2122 target_var_);
2123 visitor->EndVisitConstraint(ModelVisitor::kScalProdEqual, this);
2124 }
2125
2126 private:
2127 std::vector<IntVar*> vars_;
2128 std::vector<int64> coefs_;
2129 Rev<int> first_unbound_backward_;
2130 Rev<int64> sum_of_bound_variables_;
2131 Rev<int64> sum_of_all_variables_;
2132 Rev<int64> max_coefficient_;
2133};
2134
2135// ----- PositiveBooleanScalProd -----
2136
2137class PositiveBooleanScalProd : public BaseIntExpr {
2138 public:
2139 // this constructor will copy the array. The caller can safely delete the
2140 // exprs array himself
2141 PositiveBooleanScalProd(Solver* const s, const std::vector<IntVar*>& vars,
2142 const std::vector<int64>& coefs)
2143 : BaseIntExpr(s), vars_(vars), coefs_(coefs) {
2144 CHECK(!vars.empty());
2145 SortBothChangeConstant(&vars_, &coefs_, true);
2146 for (int i = 0; i < vars_.size(); ++i) {
2147 DCHECK_GE(coefs_[i], 0);
2148 }
2149 }
2150
2151 ~PositiveBooleanScalProd() override {}
2152
2153 int64 Min() const override {
2154 int64 min = 0;
2155 for (int i = 0; i < vars_.size(); ++i) {
2156 if (vars_[i]->Min()) {
2157 min = CapAdd(min, coefs_[i]);
2158 }
2159 }
2160 return min;
2161 }
2162
2163 void SetMin(int64 m) override { SetRange(m, kint64max); }
2164
2165 int64 Max() const override {
2166 int64 max = 0;
2167 for (int i = 0; i < vars_.size(); ++i) {
2168 if (vars_[i]->Max()) {
2169 max = CapAdd(max, coefs_[i]);
2170 }
2171 }
2172 return max;
2173 }
2174
2175 void SetMax(int64 m) override { SetRange(kint64min, m); }
2176
2177 void SetRange(int64 l, int64 u) override {
2178 int64 current_min = 0;
2179 int64 current_max = 0;
2180 int64 diameter = -1;
2181 for (int i = 0; i < vars_.size(); ++i) {
2182 const int64 coefficient = coefs_[i];
2183 const int64 var_min = CapProd(vars_[i]->Min(), coefficient);
2184 const int64 var_max = CapProd(vars_[i]->Max(), coefficient);
2185 current_min = CapAdd(current_min, var_min);
2186 current_max = CapAdd(current_max, var_max);
2187 if (var_min != var_max) { // Coefficients are increasing.
2188 diameter = CapSub(var_max, var_min);
2189 }
2190 }
2191 if (u >= current_max && l <= current_min) {
2192 return;
2193 }
2194 if (u < current_min || l > current_max) {
2195 solver()->Fail();
2196 }
2197
2198 u = std::min(current_max, u);
2199 l = std::max(l, current_min);
2200
2201 if (CapSub(u, l) > diameter) {
2202 return;
2203 }
2204
2205 for (int i = 0; i < vars_.size(); ++i) {
2206 const int64 coefficient = coefs_[i];
2207 IntVar* const var = vars_[i];
2208 const int64 new_min =
2209 CapAdd(CapSub(l, current_max), CapProd(var->Max(), coefficient));
2210 const int64 new_max =
2211 CapAdd(CapSub(u, current_min), CapProd(var->Min(), coefficient));
2212 if (new_max < 0 || new_min > coefficient || new_min > new_max) {
2213 solver()->Fail();
2214 }
2215 if (new_min > 0LL) {
2216 var->SetMin(int64{1});
2217 } else if (new_max < coefficient) {
2218 var->SetMax(int64{0});
2219 }
2220 }
2221 }
2222
2223 std::string DebugString() const override {
2224 return absl::StrFormat("PositiveBooleanScalProd([%s], [%s])",
2226 absl::StrJoin(coefs_, ", "));
2227 }
2228
2229 void WhenRange(Demon* d) override {
2230 for (int i = 0; i < vars_.size(); ++i) {
2231 vars_[i]->WhenRange(d);
2232 }
2233 }
2234 IntVar* CastToVar() override {
2235 Solver* const s = solver();
2236 int64 vmin = 0LL;
2237 int64 vmax = 0LL;
2238 Range(&vmin, &vmax);
2239 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
2240 if (!vars_.empty()) {
2241 CastConstraint* const ct =
2242 s->RevAlloc(new PositiveBooleanScalProdEqVar(s, vars_, coefs_, var));
2243 s->AddCastConstraint(ct, var, this);
2244 }
2245 return var;
2246 }
2247
2248 void Accept(ModelVisitor* const visitor) const override {
2249 visitor->BeginVisitIntegerExpression(ModelVisitor::kScalProd, this);
2250 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
2251 vars_);
2252 visitor->VisitIntegerArrayArgument(ModelVisitor::kCoefficientsArgument,
2253 coefs_);
2254 visitor->EndVisitIntegerExpression(ModelVisitor::kScalProd, this);
2255 }
2256
2257 private:
2258 std::vector<IntVar*> vars_;
2259 std::vector<int64> coefs_;
2260};
2261
2262// ----- PositiveBooleanScalProdEqCst ----- (all constants >= 0)
2263
2264class PositiveBooleanScalProdEqCst : public Constraint {
2265 public:
2266 PositiveBooleanScalProdEqCst(Solver* const s,
2267 const std::vector<IntVar*>& vars,
2268 const std::vector<int64>& coefs, int64 constant)
2269 : Constraint(s),
2270 vars_(vars),
2271 coefs_(coefs),
2272 first_unbound_backward_(vars.size() - 1),
2273 sum_of_bound_variables_(0LL),
2274 sum_of_all_variables_(0LL),
2275 constant_(constant),
2276 max_coefficient_(0) {
2277 CHECK(!vars.empty());
2278 constant_ =
2279 CapSub(constant_, SortBothChangeConstant(&vars_, &coefs_, false));
2280 max_coefficient_.SetValue(s, coefs_[vars_.size() - 1]);
2281 }
2282
2283 ~PositiveBooleanScalProdEqCst() override {}
2284
2285 void Post() override {
2286 for (int var_index = 0; var_index < vars_.size(); ++var_index) {
2287 if (!vars_[var_index]->Bound()) {
2288 Demon* const d = MakeConstraintDemon1(
2289 solver(), this, &PositiveBooleanScalProdEqCst::Update, "Update",
2290 var_index);
2291 vars_[var_index]->WhenRange(d);
2292 }
2293 }
2294 }
2295
2296 void Propagate() {
2297 if (sum_of_bound_variables_.Value() > constant_ ||
2298 sum_of_all_variables_.Value() < constant_) {
2299 solver()->Fail();
2300 }
2301 const int64 slack_up = CapSub(constant_, sum_of_bound_variables_.Value());
2302 const int64 slack_down = CapSub(sum_of_all_variables_.Value(), constant_);
2303 const int64 max_coeff = max_coefficient_.Value();
2304 if (slack_down < max_coeff || slack_up < max_coeff) {
2305 int64 last_unbound = first_unbound_backward_.Value();
2306 for (; last_unbound >= 0; --last_unbound) {
2307 if (!vars_[last_unbound]->Bound()) {
2308 if (coefs_[last_unbound] > slack_up) {
2309 vars_[last_unbound]->SetValue(0);
2310 } else if (coefs_[last_unbound] > slack_down) {
2311 vars_[last_unbound]->SetValue(1);
2312 } else {
2313 max_coefficient_.SetValue(solver(), coefs_[last_unbound]);
2314 break;
2315 }
2316 }
2317 }
2318 first_unbound_backward_.SetValue(solver(), last_unbound);
2319 }
2320 }
2321
2322 void InitialPropagate() override {
2323 Solver* const s = solver();
2324 int last_unbound = -1;
2325 int64 sum_bound = 0LL;
2326 int64 sum_all = 0LL;
2327 for (int index = 0; index < vars_.size(); ++index) {
2328 const int64 value = CapProd(vars_[index]->Max(), coefs_[index]);
2329 sum_all = CapAdd(sum_all, value);
2330 if (vars_[index]->Bound()) {
2331 sum_bound = CapAdd(sum_bound, value);
2332 } else {
2333 last_unbound = index;
2334 }
2335 }
2336 sum_of_bound_variables_.SetValue(s, sum_bound);
2337 sum_of_all_variables_.SetValue(s, sum_all);
2338 first_unbound_backward_.SetValue(s, last_unbound);
2339 Propagate();
2340 }
2341
2342 void Update(int var_index) {
2343 if (vars_[var_index]->Min() == 1) {
2344 sum_of_bound_variables_.SetValue(
2345 solver(), CapAdd(sum_of_bound_variables_.Value(), coefs_[var_index]));
2346 } else {
2347 sum_of_all_variables_.SetValue(
2348 solver(), CapSub(sum_of_all_variables_.Value(), coefs_[var_index]));
2349 }
2350 Propagate();
2351 }
2352
2353 std::string DebugString() const override {
2354 return absl::StrFormat("PositiveBooleanScalProd([%s], [%s]) == %d",
2356 absl::StrJoin(coefs_, ", "), constant_);
2357 }
2358
2359 void Accept(ModelVisitor* const visitor) const override {
2360 visitor->BeginVisitConstraint(ModelVisitor::kScalProdEqual, this);
2361 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
2362 vars_);
2363 visitor->VisitIntegerArrayArgument(ModelVisitor::kCoefficientsArgument,
2364 coefs_);
2365 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, constant_);
2366 visitor->EndVisitConstraint(ModelVisitor::kScalProdEqual, this);
2367 }
2368
2369 private:
2370 std::vector<IntVar*> vars_;
2371 std::vector<int64> coefs_;
2372 Rev<int> first_unbound_backward_;
2373 Rev<int64> sum_of_bound_variables_;
2374 Rev<int64> sum_of_all_variables_;
2375 int64 constant_;
2376 Rev<int64> max_coefficient_;
2377};
2378
2379// ----- Linearizer -----
2380
2381#define IS_TYPE(type, tag) type.compare(ModelVisitor::tag) == 0
2382
2383class ExprLinearizer : public ModelParser {
2384 public:
2385 explicit ExprLinearizer(
2386 absl::flat_hash_map<IntVar*, int64>* const variables_to_coefficients)
2387 : variables_to_coefficients_(variables_to_coefficients), constant_(0) {}
2388
2389 ~ExprLinearizer() override {}
2390
2391 // Begin/End visit element.
2392 void BeginVisitModel(const std::string& solver_name) override {
2393 LOG(FATAL) << "Should not be here";
2394 }
2395
2396 void EndVisitModel(const std::string& solver_name) override {
2397 LOG(FATAL) << "Should not be here";
2398 }
2399
2400 void BeginVisitConstraint(const std::string& type_name,
2401 const Constraint* const constraint) override {
2402 LOG(FATAL) << "Should not be here";
2403 }
2404
2405 void EndVisitConstraint(const std::string& type_name,
2406 const Constraint* const constraint) override {
2407 LOG(FATAL) << "Should not be here";
2408 }
2409
2410 void BeginVisitExtension(const std::string& type) override {}
2411
2412 void EndVisitExtension(const std::string& type) override {}
2413 void BeginVisitIntegerExpression(const std::string& type_name,
2414 const IntExpr* const expr) override {
2415 BeginVisit(true);
2416 }
2417
2418 void EndVisitIntegerExpression(const std::string& type_name,
2419 const IntExpr* const expr) override {
2420 if (IS_TYPE(type_name, kSum)) {
2421 VisitSum(expr);
2422 } else if (IS_TYPE(type_name, kScalProd)) {
2423 VisitScalProd(expr);
2424 } else if (IS_TYPE(type_name, kDifference)) {
2425 VisitDifference(expr);
2426 } else if (IS_TYPE(type_name, kOpposite)) {
2427 VisitOpposite(expr);
2428 } else if (IS_TYPE(type_name, kProduct)) {
2429 VisitProduct(expr);
2430 } else if (IS_TYPE(type_name, kTrace)) {
2431 VisitTrace(expr);
2432 } else {
2433 VisitIntegerExpression(expr);
2434 }
2435 EndVisit();
2436 }
2437
2438 void VisitIntegerVariable(const IntVar* const variable,
2439 const std::string& operation, int64 value,
2440 IntVar* const delegate) override {
2441 if (operation == ModelVisitor::kSumOperation) {
2442 AddConstant(value);
2443 VisitSubExpression(delegate);
2444 } else if (operation == ModelVisitor::kDifferenceOperation) {
2445 AddConstant(value);
2446 PushMultiplier(-1);
2447 VisitSubExpression(delegate);
2448 PopMultiplier();
2449 } else if (operation == ModelVisitor::kProductOperation) {
2450 PushMultiplier(value);
2451 VisitSubExpression(delegate);
2452 PopMultiplier();
2453 } else if (operation == ModelVisitor::kTraceOperation) {
2454 VisitSubExpression(delegate);
2455 }
2456 }
2457
2458 void VisitIntegerVariable(const IntVar* const variable,
2459 IntExpr* const delegate) override {
2460 if (delegate != nullptr) {
2461 VisitSubExpression(delegate);
2462 } else {
2463 if (variable->Bound()) {
2464 AddConstant(variable->Min());
2465 } else {
2466 RegisterExpression(variable, 1);
2467 }
2468 }
2469 }
2470
2471 // Visit integer arguments.
2472 void VisitIntegerArgument(const std::string& arg_name, int64 value) override {
2473 Top()->SetIntegerArgument(arg_name, value);
2474 }
2475
2476 void VisitIntegerArrayArgument(const std::string& arg_name,
2477 const std::vector<int64>& values) override {
2478 Top()->SetIntegerArrayArgument(arg_name, values);
2479 }
2480
2481 void VisitIntegerMatrixArgument(const std::string& arg_name,
2482 const IntTupleSet& values) override {
2483 Top()->SetIntegerMatrixArgument(arg_name, values);
2484 }
2485
2486 // Visit integer expression argument.
2487 void VisitIntegerExpressionArgument(const std::string& arg_name,
2488 IntExpr* const argument) override {
2489 Top()->SetIntegerExpressionArgument(arg_name, argument);
2490 }
2491
2492 void VisitIntegerVariableArrayArgument(
2493 const std::string& arg_name,
2494 const std::vector<IntVar*>& arguments) override {
2495 Top()->SetIntegerVariableArrayArgument(arg_name, arguments);
2496 }
2497
2498 // Visit interval argument.
2499 void VisitIntervalArgument(const std::string& arg_name,
2500 IntervalVar* const argument) override {}
2501
2502 void VisitIntervalArrayArgument(
2503 const std::string& arg_name,
2504 const std::vector<IntervalVar*>& argument) override {}
2505
2506 void Visit(const IntExpr* const expr, int64 multiplier) {
2507 if (expr->Min() == expr->Max()) {
2508 constant_ = CapAdd(constant_, CapProd(expr->Min(), multiplier));
2509 } else {
2510 PushMultiplier(multiplier);
2511 expr->Accept(this);
2512 PopMultiplier();
2513 }
2514 }
2515
2516 int64 Constant() const { return constant_; }
2517
2518 std::string DebugString() const override { return "ExprLinearizer"; }
2519
2520 private:
2521 void BeginVisit(bool active) { PushArgumentHolder(); }
2522
2523 void EndVisit() { PopArgumentHolder(); }
2524
2525 void VisitSubExpression(const IntExpr* const cp_expr) {
2526 cp_expr->Accept(this);
2527 }
2528
2529 void VisitSum(const IntExpr* const cp_expr) {
2530 if (Top()->HasIntegerVariableArrayArgument(ModelVisitor::kVarsArgument)) {
2531 const std::vector<IntVar*>& cp_vars =
2532 Top()->FindIntegerVariableArrayArgumentOrDie(
2533 ModelVisitor::kVarsArgument);
2534 for (int i = 0; i < cp_vars.size(); ++i) {
2535 VisitSubExpression(cp_vars[i]);
2536 }
2537 } else if (Top()->HasIntegerExpressionArgument(
2538 ModelVisitor::kLeftArgument)) {
2539 const IntExpr* const left = Top()->FindIntegerExpressionArgumentOrDie(
2540 ModelVisitor::kLeftArgument);
2541 const IntExpr* const right = Top()->FindIntegerExpressionArgumentOrDie(
2542 ModelVisitor::kRightArgument);
2543 VisitSubExpression(left);
2544 VisitSubExpression(right);
2545 } else {
2546 const IntExpr* const expr = Top()->FindIntegerExpressionArgumentOrDie(
2547 ModelVisitor::kExpressionArgument);
2548 const int64 value =
2549 Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
2550 VisitSubExpression(expr);
2551 AddConstant(value);
2552 }
2553 }
2554
2555 void VisitScalProd(const IntExpr* const cp_expr) {
2556 const std::vector<IntVar*>& cp_vars =
2557 Top()->FindIntegerVariableArrayArgumentOrDie(
2558 ModelVisitor::kVarsArgument);
2559 const std::vector<int64>& cp_coefficients =
2560 Top()->FindIntegerArrayArgumentOrDie(
2561 ModelVisitor::kCoefficientsArgument);
2562 CHECK_EQ(cp_vars.size(), cp_coefficients.size());
2563 for (int i = 0; i < cp_vars.size(); ++i) {
2564 const int64 coefficient = cp_coefficients[i];
2565 PushMultiplier(coefficient);
2566 VisitSubExpression(cp_vars[i]);
2567 PopMultiplier();
2568 }
2569 }
2570
2571 void VisitDifference(const IntExpr* const cp_expr) {
2572 if (Top()->HasIntegerExpressionArgument(ModelVisitor::kLeftArgument)) {
2573 const IntExpr* const left = Top()->FindIntegerExpressionArgumentOrDie(
2574 ModelVisitor::kLeftArgument);
2575 const IntExpr* const right = Top()->FindIntegerExpressionArgumentOrDie(
2576 ModelVisitor::kRightArgument);
2577 VisitSubExpression(left);
2578 PushMultiplier(-1);
2579 VisitSubExpression(right);
2580 PopMultiplier();
2581 } else {
2582 const IntExpr* const expr = Top()->FindIntegerExpressionArgumentOrDie(
2583 ModelVisitor::kExpressionArgument);
2584 const int64 value =
2585 Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
2586 AddConstant(value);
2587 PushMultiplier(-1);
2588 VisitSubExpression(expr);
2589 PopMultiplier();
2590 }
2591 }
2592
2593 void VisitOpposite(const IntExpr* const cp_expr) {
2594 const IntExpr* const expr = Top()->FindIntegerExpressionArgumentOrDie(
2595 ModelVisitor::kExpressionArgument);
2596 PushMultiplier(-1);
2597 VisitSubExpression(expr);
2598 PopMultiplier();
2599 }
2600
2601 void VisitProduct(const IntExpr* const cp_expr) {
2602 if (Top()->HasIntegerExpressionArgument(
2603 ModelVisitor::kExpressionArgument)) {
2604 const IntExpr* const expr = Top()->FindIntegerExpressionArgumentOrDie(
2605 ModelVisitor::kExpressionArgument);
2606 const int64 value =
2607 Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
2608 PushMultiplier(value);
2609 VisitSubExpression(expr);
2610 PopMultiplier();
2611 } else {
2612 RegisterExpression(cp_expr, 1);
2613 }
2614 }
2615
2616 void VisitTrace(const IntExpr* const cp_expr) {
2617 const IntExpr* const expr = Top()->FindIntegerExpressionArgumentOrDie(
2618 ModelVisitor::kExpressionArgument);
2619 VisitSubExpression(expr);
2620 }
2621
2622 void VisitIntegerExpression(const IntExpr* const cp_expr) {
2623 RegisterExpression(cp_expr, 1);
2624 }
2625
2626 void RegisterExpression(const IntExpr* const expr, int64 coef) {
2627 int64& value =
2628 (*variables_to_coefficients_)[const_cast<IntExpr*>(expr)->Var()];
2629 value = CapAdd(value, CapProd(coef, multipliers_.back()));
2630 }
2631
2632 void AddConstant(int64 constant) {
2633 constant_ = CapAdd(constant_, CapProd(constant, multipliers_.back()));
2634 }
2635
2636 void PushMultiplier(int64 multiplier) {
2637 if (multipliers_.empty()) {
2638 multipliers_.push_back(multiplier);
2639 } else {
2640 multipliers_.push_back(CapProd(multiplier, multipliers_.back()));
2641 }
2642 }
2643
2644 void PopMultiplier() { multipliers_.pop_back(); }
2645
2646 // We do need a IntVar* as key, and not const IntVar*, because clients of this
2647 // class typically iterate over the map keys and use them as mutable IntVar*.
2648 absl::flat_hash_map<IntVar*, int64>* const variables_to_coefficients_;
2649 std::vector<int64> multipliers_;
2650 int64 constant_;
2651};
2652#undef IS_TYPE
2653
2654// ----- Factory functions -----
2655
2656void DeepLinearize(Solver* const solver, const std::vector<IntVar*>& pre_vars,
2657 const std::vector<int64>& pre_coefs,
2658 std::vector<IntVar*>* vars, std::vector<int64>* coefs,
2659 int64* constant) {
2660 CHECK(solver != nullptr);
2661 CHECK(vars != nullptr);
2662 CHECK(coefs != nullptr);
2663 CHECK(constant != nullptr);
2664 *constant = 0;
2665 vars->reserve(pre_vars.size());
2666 coefs->reserve(pre_coefs.size());
2667 // Try linear scan of the variables to check if there is nothing to do.
2668 bool need_linearization = false;
2669 for (int i = 0; i < pre_vars.size(); ++i) {
2670 IntVar* const variable = pre_vars[i];
2671 const int64 coefficient = pre_coefs[i];
2672 if (variable->Bound()) {
2673 *constant = CapAdd(*constant, CapProd(coefficient, variable->Min()));
2674 } else if (solver->CastExpression(variable) == nullptr) {
2675 vars->push_back(variable);
2676 coefs->push_back(coefficient);
2677 } else {
2678 need_linearization = true;
2679 vars->clear();
2680 coefs->clear();
2681 break;
2682 }
2683 }
2684 if (need_linearization) {
2685 // Instrospect the variables to simplify the sum.
2686 absl::flat_hash_map<IntVar*, int64> variables_to_coefficients;
2687 ExprLinearizer linearizer(&variables_to_coefficients);
2688 for (int i = 0; i < pre_vars.size(); ++i) {
2689 linearizer.Visit(pre_vars[i], pre_coefs[i]);
2690 }
2691 *constant = linearizer.Constant();
2692 for (const auto& variable_to_coefficient : variables_to_coefficients) {
2693 if (variable_to_coefficient.second != 0) {
2694 vars->push_back(variable_to_coefficient.first);
2695 coefs->push_back(variable_to_coefficient.second);
2696 }
2697 }
2698 }
2699}
2700
2701Constraint* MakeScalProdEqualityFct(Solver* const solver,
2702 const std::vector<IntVar*>& pre_vars,
2703 const std::vector<int64>& pre_coefs,
2704 int64 cst) {
2705 int64 constant = 0;
2706 std::vector<IntVar*> vars;
2707 std::vector<int64> coefs;
2708 DeepLinearize(solver, pre_vars, pre_coefs, &vars, &coefs, &constant);
2709 cst = CapSub(cst, constant);
2710
2711 const int size = vars.size();
2712 if (size == 0 || AreAllNull(coefs)) {
2713 return cst == 0 ? solver->MakeTrueConstraint()
2714 : solver->MakeFalseConstraint();
2715 }
2716 if (AreAllBoundOrNull(vars, coefs)) {
2717 int64 sum = 0;
2718 for (int i = 0; i < size; ++i) {
2719 sum = CapAdd(sum, CapProd(coefs[i], vars[i]->Min()));
2720 }
2721 return sum == cst ? solver->MakeTrueConstraint()
2722 : solver->MakeFalseConstraint();
2723 }
2724 if (AreAllOnes(coefs)) {
2725 return solver->MakeSumEquality(vars, cst);
2726 }
2727 if (AreAllBooleans(vars) && size > 2) {
2728 if (AreAllPositive(coefs)) {
2729 return solver->RevAlloc(
2730 new PositiveBooleanScalProdEqCst(solver, vars, coefs, cst));
2731 }
2732 if (AreAllNegative(coefs)) {
2733 std::vector<int64> opp_coefs(coefs.size());
2734 for (int i = 0; i < coefs.size(); ++i) {
2735 opp_coefs[i] = -coefs[i];
2736 }
2737 return solver->RevAlloc(
2738 new PositiveBooleanScalProdEqCst(solver, vars, opp_coefs, -cst));
2739 }
2740 }
2741
2742 // Simplications.
2743 int constants = 0;
2744 int positives = 0;
2745 int negatives = 0;
2746 for (int i = 0; i < size; ++i) {
2747 if (coefs[i] == 0 || vars[i]->Bound()) {
2748 constants++;
2749 } else if (coefs[i] > 0) {
2750 positives++;
2751 } else {
2752 negatives++;
2753 }
2754 }
2755 if (positives > 0 && negatives > 0) {
2756 std::vector<IntVar*> pos_terms;
2757 std::vector<IntVar*> neg_terms;
2758 int64 rhs = cst;
2759 for (int i = 0; i < size; ++i) {
2760 if (coefs[i] == 0 || vars[i]->Bound()) {
2761 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2762 } else if (coefs[i] > 0) {
2763 pos_terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2764 } else {
2765 neg_terms.push_back(solver->MakeProd(vars[i], -coefs[i])->Var());
2766 }
2767 }
2768 if (negatives == 1) {
2769 if (rhs != 0) {
2770 pos_terms.push_back(solver->MakeIntConst(-rhs));
2771 }
2772 return solver->MakeSumEquality(pos_terms, neg_terms[0]);
2773 } else if (positives == 1) {
2774 if (rhs != 0) {
2775 neg_terms.push_back(solver->MakeIntConst(rhs));
2776 }
2777 return solver->MakeSumEquality(neg_terms, pos_terms[0]);
2778 } else {
2779 if (rhs != 0) {
2780 neg_terms.push_back(solver->MakeIntConst(rhs));
2781 }
2782 return solver->MakeEquality(solver->MakeSum(pos_terms),
2783 solver->MakeSum(neg_terms));
2784 }
2785 } else if (positives == 1) {
2786 IntExpr* pos_term = nullptr;
2787 int64 rhs = cst;
2788 for (int i = 0; i < size; ++i) {
2789 if (coefs[i] == 0 || vars[i]->Bound()) {
2790 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2791 } else if (coefs[i] > 0) {
2792 pos_term = solver->MakeProd(vars[i], coefs[i]);
2793 } else {
2794 LOG(FATAL) << "Should not be here";
2795 }
2796 }
2797 return solver->MakeEquality(pos_term, rhs);
2798 } else if (negatives == 1) {
2799 IntExpr* neg_term = nullptr;
2800 int64 rhs = cst;
2801 for (int i = 0; i < size; ++i) {
2802 if (coefs[i] == 0 || vars[i]->Bound()) {
2803 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2804 } else if (coefs[i] > 0) {
2805 LOG(FATAL) << "Should not be here";
2806 } else {
2807 neg_term = solver->MakeProd(vars[i], -coefs[i]);
2808 }
2809 }
2810 return solver->MakeEquality(neg_term, -rhs);
2811 } else if (positives > 1) {
2812 std::vector<IntVar*> pos_terms;
2813 int64 rhs = cst;
2814 for (int i = 0; i < size; ++i) {
2815 if (coefs[i] == 0 || vars[i]->Bound()) {
2816 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2817 } else if (coefs[i] > 0) {
2818 pos_terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2819 } else {
2820 LOG(FATAL) << "Should not be here";
2821 }
2822 }
2823 return solver->MakeSumEquality(pos_terms, rhs);
2824 } else if (negatives > 1) {
2825 std::vector<IntVar*> neg_terms;
2826 int64 rhs = cst;
2827 for (int i = 0; i < size; ++i) {
2828 if (coefs[i] == 0 || vars[i]->Bound()) {
2829 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2830 } else if (coefs[i] > 0) {
2831 LOG(FATAL) << "Should not be here";
2832 } else {
2833 neg_terms.push_back(solver->MakeProd(vars[i], -coefs[i])->Var());
2834 }
2835 }
2836 return solver->MakeSumEquality(neg_terms, -rhs);
2837 }
2838 std::vector<IntVar*> terms;
2839 for (int i = 0; i < size; ++i) {
2840 terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2841 }
2842 return solver->MakeSumEquality(terms, solver->MakeIntConst(cst));
2843}
2844
2845Constraint* MakeScalProdEqualityVarFct(Solver* const solver,
2846 const std::vector<IntVar*>& pre_vars,
2847 const std::vector<int64>& pre_coefs,
2848 IntVar* const target) {
2849 int64 constant = 0;
2850 std::vector<IntVar*> vars;
2851 std::vector<int64> coefs;
2852 DeepLinearize(solver, pre_vars, pre_coefs, &vars, &coefs, &constant);
2853
2854 const int size = vars.size();
2855 if (size == 0 || AreAllNull<int64>(coefs)) {
2856 return solver->MakeEquality(target, constant);
2857 }
2858 if (AreAllOnes(coefs)) {
2859 return solver->MakeSumEquality(vars,
2860 solver->MakeSum(target, -constant)->Var());
2861 }
2862 if (AreAllBooleans(vars) && AreAllPositive<int64>(coefs)) {
2863 // TODO(user) : bench BooleanScalProdEqVar with IntConst.
2864 return solver->RevAlloc(new PositiveBooleanScalProdEqVar(
2865 solver, vars, coefs, solver->MakeSum(target, -constant)->Var()));
2866 }
2867 std::vector<IntVar*> terms;
2868 for (int i = 0; i < size; ++i) {
2869 terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2870 }
2871 return solver->MakeSumEquality(terms,
2872 solver->MakeSum(target, -constant)->Var());
2873}
2874
2875Constraint* MakeScalProdGreaterOrEqualFct(Solver* solver,
2876 const std::vector<IntVar*>& pre_vars,
2877 const std::vector<int64>& pre_coefs,
2878 int64 cst) {
2879 int64 constant = 0;
2880 std::vector<IntVar*> vars;
2881 std::vector<int64> coefs;
2882 DeepLinearize(solver, pre_vars, pre_coefs, &vars, &coefs, &constant);
2883 cst = CapSub(cst, constant);
2884
2885 const int size = vars.size();
2886 if (size == 0 || AreAllNull<int64>(coefs)) {
2887 return cst <= 0 ? solver->MakeTrueConstraint()
2888 : solver->MakeFalseConstraint();
2889 }
2890 if (AreAllOnes(coefs)) {
2891 return solver->MakeSumGreaterOrEqual(vars, cst);
2892 }
2893 if (cst == 1 && AreAllBooleans(vars) && AreAllPositive(coefs)) {
2894 // can move all coefs to 1.
2895 std::vector<IntVar*> terms;
2896 for (int i = 0; i < size; ++i) {
2897 if (coefs[i] > 0) {
2898 terms.push_back(vars[i]);
2899 }
2900 }
2901 return solver->MakeSumGreaterOrEqual(terms, 1);
2902 }
2903 std::vector<IntVar*> terms;
2904 for (int i = 0; i < size; ++i) {
2905 terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2906 }
2907 return solver->MakeSumGreaterOrEqual(terms, cst);
2908}
2909
2910Constraint* MakeScalProdLessOrEqualFct(Solver* solver,
2911 const std::vector<IntVar*>& pre_vars,
2912 const std::vector<int64>& pre_coefs,
2913 int64 upper_bound) {
2914 int64 constant = 0;
2915 std::vector<IntVar*> vars;
2916 std::vector<int64> coefs;
2917 DeepLinearize(solver, pre_vars, pre_coefs, &vars, &coefs, &constant);
2918 upper_bound = CapSub(upper_bound, constant);
2919
2920 const int size = vars.size();
2921 if (size == 0 || AreAllNull<int64>(coefs)) {
2922 return upper_bound >= 0 ? solver->MakeTrueConstraint()
2923 : solver->MakeFalseConstraint();
2924 }
2925 // TODO(user) : compute constant on the fly.
2926 if (AreAllBoundOrNull(vars, coefs)) {
2927 int64 cst = 0;
2928 for (int i = 0; i < size; ++i) {
2929 cst = CapAdd(cst, CapProd(vars[i]->Min(), coefs[i]));
2930 }
2931 return cst <= upper_bound ? solver->MakeTrueConstraint()
2932 : solver->MakeFalseConstraint();
2933 }
2934 if (AreAllOnes(coefs)) {
2935 return solver->MakeSumLessOrEqual(vars, upper_bound);
2936 }
2937 if (AreAllBooleans(vars) && AreAllPositive<int64>(coefs)) {
2938 return solver->RevAlloc(
2939 new BooleanScalProdLessConstant(solver, vars, coefs, upper_bound));
2940 }
2941 // Some simplications
2942 int constants = 0;
2943 int positives = 0;
2944 int negatives = 0;
2945 for (int i = 0; i < size; ++i) {
2946 if (coefs[i] == 0 || vars[i]->Bound()) {
2947 constants++;
2948 } else if (coefs[i] > 0) {
2949 positives++;
2950 } else {
2951 negatives++;
2952 }
2953 }
2954 if (positives > 0 && negatives > 0) {
2955 std::vector<IntVar*> pos_terms;
2956 std::vector<IntVar*> neg_terms;
2957 int64 rhs = upper_bound;
2958 for (int i = 0; i < size; ++i) {
2959 if (coefs[i] == 0 || vars[i]->Bound()) {
2960 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2961 } else if (coefs[i] > 0) {
2962 pos_terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
2963 } else {
2964 neg_terms.push_back(solver->MakeProd(vars[i], -coefs[i])->Var());
2965 }
2966 }
2967 if (negatives == 1) {
2968 IntExpr* const neg_term = solver->MakeSum(neg_terms[0], rhs);
2969 return solver->MakeLessOrEqual(solver->MakeSum(pos_terms), neg_term);
2970 } else if (positives == 1) {
2971 IntExpr* const pos_term = solver->MakeSum(pos_terms[0], -rhs);
2972 return solver->MakeGreaterOrEqual(solver->MakeSum(neg_terms), pos_term);
2973 } else {
2974 if (rhs != 0) {
2975 neg_terms.push_back(solver->MakeIntConst(rhs));
2976 }
2977 return solver->MakeLessOrEqual(solver->MakeSum(pos_terms),
2978 solver->MakeSum(neg_terms));
2979 }
2980 } else if (positives == 1) {
2981 IntExpr* pos_term = nullptr;
2982 int64 rhs = upper_bound;
2983 for (int i = 0; i < size; ++i) {
2984 if (coefs[i] == 0 || vars[i]->Bound()) {
2985 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2986 } else if (coefs[i] > 0) {
2987 pos_term = solver->MakeProd(vars[i], coefs[i]);
2988 } else {
2989 LOG(FATAL) << "Should not be here";
2990 }
2991 }
2992 return solver->MakeLessOrEqual(pos_term, rhs);
2993 } else if (negatives == 1) {
2994 IntExpr* neg_term = nullptr;
2995 int64 rhs = upper_bound;
2996 for (int i = 0; i < size; ++i) {
2997 if (coefs[i] == 0 || vars[i]->Bound()) {
2998 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
2999 } else if (coefs[i] > 0) {
3000 LOG(FATAL) << "Should not be here";
3001 } else {
3002 neg_term = solver->MakeProd(vars[i], -coefs[i]);
3003 }
3004 }
3005 return solver->MakeGreaterOrEqual(neg_term, -rhs);
3006 } else if (positives > 1) {
3007 std::vector<IntVar*> pos_terms;
3008 int64 rhs = upper_bound;
3009 for (int i = 0; i < size; ++i) {
3010 if (coefs[i] == 0 || vars[i]->Bound()) {
3011 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
3012 } else if (coefs[i] > 0) {
3013 pos_terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
3014 } else {
3015 LOG(FATAL) << "Should not be here";
3016 }
3017 }
3018 return solver->MakeSumLessOrEqual(pos_terms, rhs);
3019 } else if (negatives > 1) {
3020 std::vector<IntVar*> neg_terms;
3021 int64 rhs = upper_bound;
3022 for (int i = 0; i < size; ++i) {
3023 if (coefs[i] == 0 || vars[i]->Bound()) {
3024 rhs = CapSub(rhs, CapProd(coefs[i], vars[i]->Min()));
3025 } else if (coefs[i] > 0) {
3026 LOG(FATAL) << "Should not be here";
3027 } else {
3028 neg_terms.push_back(solver->MakeProd(vars[i], -coefs[i])->Var());
3029 }
3030 }
3031 return solver->MakeSumGreaterOrEqual(neg_terms, -rhs);
3032 }
3033 std::vector<IntVar*> terms;
3034 for (int i = 0; i < size; ++i) {
3035 terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
3036 }
3037 return solver->MakeLessOrEqual(solver->MakeSum(terms), upper_bound);
3038}
3039
3040IntExpr* MakeSumArrayAux(Solver* const solver, const std::vector<IntVar*>& vars,
3041 int64 constant) {
3042 const int size = vars.size();
3043 DCHECK_GT(size, 2);
3044 int64 new_min = 0;
3045 int64 new_max = 0;
3046 for (int i = 0; i < size; ++i) {
3047 if (new_min != kint64min) {
3048 new_min = CapAdd(vars[i]->Min(), new_min);
3049 }
3050 if (new_max != kint64max) {
3051 new_max = CapAdd(vars[i]->Max(), new_max);
3052 }
3053 }
3054 IntExpr* const cache =
3055 solver->Cache()->FindVarArrayExpression(vars, ModelCache::VAR_ARRAY_SUM);
3056 if (cache != nullptr) {
3057 return solver->MakeSum(cache, constant);
3058 } else {
3059 const std::string name =
3060 absl::StrFormat("Sum([%s])", JoinNamePtr(vars, ", "));
3061 IntVar* const sum_var = solver->MakeIntVar(new_min, new_max, name);
3062 if (AreAllBooleans(vars)) {
3063 solver->AddConstraint(
3064 solver->RevAlloc(new SumBooleanEqualToVar(solver, vars, sum_var)));
3065 } else if (size <= solver->parameters().array_split_size()) {
3066 solver->AddConstraint(
3067 solver->RevAlloc(new SmallSumConstraint(solver, vars, sum_var)));
3068 } else {
3069 solver->AddConstraint(
3070 solver->RevAlloc(new SumConstraint(solver, vars, sum_var)));
3071 }
3072 solver->Cache()->InsertVarArrayExpression(sum_var, vars,
3073 ModelCache::VAR_ARRAY_SUM);
3074 return solver->MakeSum(sum_var, constant);
3075 }
3076}
3077
3078IntExpr* MakeSumAux(Solver* const solver, const std::vector<IntVar*>& vars,
3079 int64 constant) {
3080 const int size = vars.size();
3081 if (size == 0) {
3082 return solver->MakeIntConst(constant);
3083 } else if (size == 1) {
3084 return solver->MakeSum(vars[0], constant);
3085 } else if (size == 2) {
3086 return solver->MakeSum(solver->MakeSum(vars[0], vars[1]), constant);
3087 } else {
3088 return MakeSumArrayAux(solver, vars, constant);
3089 }
3090}
3091
3092IntExpr* MakeScalProdAux(Solver* solver, const std::vector<IntVar*>& vars,
3093 const std::vector<int64>& coefs, int64 constant) {
3094 if (AreAllOnes(coefs)) {
3095 return MakeSumAux(solver, vars, constant);
3096 }
3097
3098 const int size = vars.size();
3099 if (size == 0) {
3100 return solver->MakeIntConst(constant);
3101 } else if (size == 1) {
3102 return solver->MakeSum(solver->MakeProd(vars[0], coefs[0]), constant);
3103 } else if (size == 2) {
3104 if (coefs[0] > 0 && coefs[1] < 0) {
3105 return solver->MakeSum(
3106 solver->MakeDifference(solver->MakeProd(vars[0], coefs[0]),
3107 solver->MakeProd(vars[1], -coefs[1])),
3108 constant);
3109 } else if (coefs[0] < 0 && coefs[1] > 0) {
3110 return solver->MakeSum(
3111 solver->MakeDifference(solver->MakeProd(vars[1], coefs[1]),
3112 solver->MakeProd(vars[0], -coefs[0])),
3113 constant);
3114 } else {
3115 return solver->MakeSum(
3116 solver->MakeSum(solver->MakeProd(vars[0], coefs[0]),
3117 solver->MakeProd(vars[1], coefs[1])),
3118 constant);
3119 }
3120 } else {
3121 if (AreAllBooleans(vars)) {
3122 if (AreAllPositive(coefs)) {
3123 if (vars.size() > 8) {
3124 return solver->MakeSum(
3125 solver
3126 ->RegisterIntExpr(solver->RevAlloc(
3127 new PositiveBooleanScalProd(solver, vars, coefs)))
3128 ->Var(),
3129 constant);
3130 } else {
3131 return solver->MakeSum(
3132 solver->RegisterIntExpr(solver->RevAlloc(
3133 new PositiveBooleanScalProd(solver, vars, coefs))),
3134 constant);
3135 }
3136 } else {
3137 // If some coefficients are non-positive, partition coefficients in two
3138 // sets, one for the positive coefficients P and one for the negative
3139 // ones N.
3140 // Create two PositiveBooleanScalProd expressions, one on P (s1), the
3141 // other on Opposite(N) (s2).
3142 // The final expression is then s1 - s2.
3143 // If P is empty, the expression is Opposite(s2).
3144 std::vector<int64> positive_coefs;
3145 std::vector<int64> negative_coefs;
3146 std::vector<IntVar*> positive_coef_vars;
3147 std::vector<IntVar*> negative_coef_vars;
3148 for (int i = 0; i < size; ++i) {
3149 const int coef = coefs[i];
3150 if (coef > 0) {
3151 positive_coefs.push_back(coef);
3152 positive_coef_vars.push_back(vars[i]);
3153 } else if (coef < 0) {
3154 negative_coefs.push_back(-coef);
3155 negative_coef_vars.push_back(vars[i]);
3156 }
3157 }
3158 CHECK_GT(negative_coef_vars.size(), 0);
3159 IntExpr* const negatives =
3160 MakeScalProdAux(solver, negative_coef_vars, negative_coefs, 0);
3161 if (!positive_coef_vars.empty()) {
3162 IntExpr* const positives = MakeScalProdAux(solver, positive_coef_vars,
3163 positive_coefs, constant);
3164 return solver->MakeDifference(positives, negatives);
3165 } else {
3166 return solver->MakeDifference(constant, negatives);
3167 }
3168 }
3169 }
3170 }
3171 std::vector<IntVar*> terms;
3172 for (int i = 0; i < size; ++i) {
3173 terms.push_back(solver->MakeProd(vars[i], coefs[i])->Var());
3174 }
3175 return MakeSumArrayAux(solver, terms, constant);
3176}
3177
3178IntExpr* MakeScalProdFct(Solver* solver, const std::vector<IntVar*>& pre_vars,
3179 const std::vector<int64>& pre_coefs) {
3180 int64 constant = 0;
3181 std::vector<IntVar*> vars;
3182 std::vector<int64> coefs;
3183 DeepLinearize(solver, pre_vars, pre_coefs, &vars, &coefs, &constant);
3184
3185 if (vars.empty()) {
3186 return solver->MakeIntConst(constant);
3187 }
3188 // Can we simplify using some gcd computation.
3189 int64 gcd = std::abs(coefs[0]);
3190 for (int i = 1; i < coefs.size(); ++i) {
3191 gcd = MathUtil::GCD64(gcd, std::abs(coefs[i]));
3192 if (gcd == 1) {
3193 break;
3194 }
3195 }
3196 if (constant != 0 && gcd != 1) {
3197 gcd = MathUtil::GCD64(gcd, std::abs(constant));
3198 }
3199 if (gcd > 1) {
3200 for (int i = 0; i < coefs.size(); ++i) {
3201 coefs[i] /= gcd;
3202 }
3203 return solver->MakeProd(
3204 MakeScalProdAux(solver, vars, coefs, constant / gcd), gcd);
3205 }
3206 return MakeScalProdAux(solver, vars, coefs, constant);
3207}
3208
3209IntExpr* MakeSumFct(Solver* solver, const std::vector<IntVar*>& pre_vars) {
3210 absl::flat_hash_map<IntVar*, int64> variables_to_coefficients;
3211 ExprLinearizer linearizer(&variables_to_coefficients);
3212 for (int i = 0; i < pre_vars.size(); ++i) {
3213 linearizer.Visit(pre_vars[i], 1);
3214 }
3215 const int64 constant = linearizer.Constant();
3216 std::vector<IntVar*> vars;
3217 std::vector<int64> coefs;
3218 for (const auto& variable_to_coefficient : variables_to_coefficients) {
3219 if (variable_to_coefficient.second != 0) {
3220 vars.push_back(variable_to_coefficient.first);
3221 coefs.push_back(variable_to_coefficient.second);
3222 }
3223 }
3224 return MakeScalProdAux(solver, vars, coefs, constant);
3225}
3226} // namespace
3227
3228// ----- API -----
3229
3230IntExpr* Solver::MakeSum(const std::vector<IntVar*>& vars) {
3231 const int size = vars.size();
3232 if (size == 0) {
3233 return MakeIntConst(int64{0});
3234 } else if (size == 1) {
3235 return vars[0];
3236 } else if (size == 2) {
3237 return MakeSum(vars[0], vars[1]);
3238 } else {
3239 IntExpr* const cache =
3240 model_cache_->FindVarArrayExpression(vars, ModelCache::VAR_ARRAY_SUM);
3241 if (cache != nullptr) {
3242 return cache;
3243 } else {
3244 int64 new_min = 0;
3245 int64 new_max = 0;
3246 for (int i = 0; i < size; ++i) {
3247 if (new_min != kint64min) {
3248 new_min = CapAdd(vars[i]->Min(), new_min);
3249 }
3250 if (new_max != kint64max) {
3251 new_max = CapAdd(vars[i]->Max(), new_max);
3252 }
3253 }
3254 IntExpr* sum_expr = nullptr;
3255 const bool all_booleans = AreAllBooleans(vars);
3256 if (all_booleans) {
3257 const std::string name =
3258 absl::StrFormat("BooleanSum([%s])", JoinNamePtr(vars, ", "));
3259 sum_expr = MakeIntVar(new_min, new_max, name);
3260 AddConstraint(
3261 RevAlloc(new SumBooleanEqualToVar(this, vars, sum_expr->Var())));
3262 } else if (new_min != kint64min && new_max != kint64max) {
3263 sum_expr = MakeSumFct(this, vars);
3264 } else {
3265 const std::string name =
3266 absl::StrFormat("Sum([%s])", JoinNamePtr(vars, ", "));
3267 sum_expr = MakeIntVar(new_min, new_max, name);
3268 AddConstraint(
3269 RevAlloc(new SafeSumConstraint(this, vars, sum_expr->Var())));
3270 }
3271 model_cache_->InsertVarArrayExpression(sum_expr, vars,
3272 ModelCache::VAR_ARRAY_SUM);
3273 return sum_expr;
3274 }
3275 }
3276}
3277
3278IntExpr* Solver::MakeMin(const std::vector<IntVar*>& vars) {
3279 const int size = vars.size();
3280 if (size == 0) {
3281 LOG(WARNING) << "operations_research::Solver::MakeMin() was called with an "
3282 "empty list of variables. Was this intentional?";
3283 return MakeIntConst(kint64max);
3284 } else if (size == 1) {
3285 return vars[0];
3286 } else if (size == 2) {
3287 return MakeMin(vars[0], vars[1]);
3288 } else {
3289 IntExpr* const cache =
3290 model_cache_->FindVarArrayExpression(vars, ModelCache::VAR_ARRAY_MIN);
3291 if (cache != nullptr) {
3292 return cache;
3293 } else {
3294 if (AreAllBooleans(vars)) {
3295 IntVar* const new_var = MakeBoolVar();
3296 AddConstraint(RevAlloc(new ArrayBoolAndEq(this, vars, new_var)));
3297 model_cache_->InsertVarArrayExpression(new_var, vars,
3298 ModelCache::VAR_ARRAY_MIN);
3299 return new_var;
3300 } else {
3301 int64 new_min = kint64max;
3302 int64 new_max = kint64max;
3303 for (int i = 0; i < size; ++i) {
3304 new_min = std::min(new_min, vars[i]->Min());
3305 new_max = std::min(new_max, vars[i]->Max());
3306 }
3307 IntVar* const new_var = MakeIntVar(new_min, new_max);
3308 if (size <= parameters_.array_split_size()) {
3309 AddConstraint(RevAlloc(new SmallMinConstraint(this, vars, new_var)));
3310 } else {
3311 AddConstraint(RevAlloc(new MinConstraint(this, vars, new_var)));
3312 }
3313 model_cache_->InsertVarArrayExpression(new_var, vars,
3314 ModelCache::VAR_ARRAY_MIN);
3315 return new_var;
3316 }
3317 }
3318 }
3319}
3320
3321IntExpr* Solver::MakeMax(const std::vector<IntVar*>& vars) {
3322 const int size = vars.size();
3323 if (size == 0) {
3324 LOG(WARNING) << "operations_research::Solver::MakeMax() was called with an "
3325 "empty list of variables. Was this intentional?";
3326 return MakeIntConst(kint64min);
3327 } else if (size == 1) {
3328 return vars[0];
3329 } else if (size == 2) {
3330 return MakeMax(vars[0], vars[1]);
3331 } else {
3332 IntExpr* const cache =
3333 model_cache_->FindVarArrayExpression(vars, ModelCache::VAR_ARRAY_MAX);
3334 if (cache != nullptr) {
3335 return cache;
3336 } else {
3337 if (AreAllBooleans(vars)) {
3338 IntVar* const new_var = MakeBoolVar();
3339 AddConstraint(RevAlloc(new ArrayBoolOrEq(this, vars, new_var)));
3340 model_cache_->InsertVarArrayExpression(new_var, vars,
3341 ModelCache::VAR_ARRAY_MIN);
3342 return new_var;
3343 } else {
3344 int64 new_min = kint64min;
3345 int64 new_max = kint64min;
3346 for (int i = 0; i < size; ++i) {
3347 new_min = std::max(new_min, vars[i]->Min());
3348 new_max = std::max(new_max, vars[i]->Max());
3349 }
3350 IntVar* const new_var = MakeIntVar(new_min, new_max);
3351 if (size <= parameters_.array_split_size()) {
3352 AddConstraint(RevAlloc(new SmallMaxConstraint(this, vars, new_var)));
3353 } else {
3354 AddConstraint(RevAlloc(new MaxConstraint(this, vars, new_var)));
3355 }
3356 model_cache_->InsertVarArrayExpression(new_var, vars,
3357 ModelCache::VAR_ARRAY_MAX);
3358 return new_var;
3359 }
3360 }
3361 }
3362}
3363
3364Constraint* Solver::MakeMinEquality(const std::vector<IntVar*>& vars,
3365 IntVar* const min_var) {
3366 const int size = vars.size();
3367 if (size > 2) {
3368 if (AreAllBooleans(vars)) {
3369 return RevAlloc(new ArrayBoolAndEq(this, vars, min_var));
3370 } else if (size <= parameters_.array_split_size()) {
3371 return RevAlloc(new SmallMinConstraint(this, vars, min_var));
3372 } else {
3373 return RevAlloc(new MinConstraint(this, vars, min_var));
3374 }
3375 } else if (size == 2) {
3376 return MakeEquality(MakeMin(vars[0], vars[1]), min_var);
3377 } else if (size == 1) {
3378 return MakeEquality(vars[0], min_var);
3379 } else {
3380 LOG(WARNING) << "operations_research::Solver::MakeMinEquality() was called "
3381 "with an empty list of variables. Was this intentional?";
3382 return MakeEquality(min_var, kint64max);
3383 }
3384}
3385
3386Constraint* Solver::MakeMaxEquality(const std::vector<IntVar*>& vars,
3387 IntVar* const max_var) {
3388 const int size = vars.size();
3389 if (size > 2) {
3390 if (AreAllBooleans(vars)) {
3391 return RevAlloc(new ArrayBoolOrEq(this, vars, max_var));
3392 } else if (size <= parameters_.array_split_size()) {
3393 return RevAlloc(new SmallMaxConstraint(this, vars, max_var));
3394 } else {
3395 return RevAlloc(new MaxConstraint(this, vars, max_var));
3396 }
3397 } else if (size == 2) {
3398 return MakeEquality(MakeMax(vars[0], vars[1]), max_var);
3399 } else if (size == 1) {
3400 return MakeEquality(vars[0], max_var);
3401 } else {
3402 LOG(WARNING) << "operations_research::Solver::MakeMaxEquality() was called "
3403 "with an empty list of variables. Was this intentional?";
3404 return MakeEquality(max_var, kint64min);
3405 }
3406}
3407
3408Constraint* Solver::MakeSumLessOrEqual(const std::vector<IntVar*>& vars,
3409 int64 cst) {
3410 const int size = vars.size();
3411 if (cst == 1LL && AreAllBooleans(vars) && size > 2) {
3412 return RevAlloc(new SumBooleanLessOrEqualToOne(this, vars));
3413 } else {
3414 return MakeLessOrEqual(MakeSum(vars), cst);
3415 }
3416}
3417
3418Constraint* Solver::MakeSumGreaterOrEqual(const std::vector<IntVar*>& vars,
3419 int64 cst) {
3420 const int size = vars.size();
3421 if (cst == 1LL && AreAllBooleans(vars) && size > 2) {
3422 return RevAlloc(new SumBooleanGreaterOrEqualToOne(this, vars));
3423 } else {
3424 return MakeGreaterOrEqual(MakeSum(vars), cst);
3425 }
3426}
3427
3428Constraint* Solver::MakeSumEquality(const std::vector<IntVar*>& vars,
3429 int64 cst) {
3430 const int size = vars.size();
3431 if (size == 0) {
3432 return cst == 0 ? MakeTrueConstraint() : MakeFalseConstraint();
3433 }
3434 if (AreAllBooleans(vars) && size > 2) {
3435 if (cst == 1) {
3436 return RevAlloc(new SumBooleanEqualToOne(this, vars));
3437 } else if (cst < 0 || cst > size) {
3438 return MakeFalseConstraint();
3439 } else {
3440 return RevAlloc(new SumBooleanEqualToVar(this, vars, MakeIntConst(cst)));
3441 }
3442 } else {
3443 if (vars.size() == 1) {
3444 return MakeEquality(vars[0], cst);
3445 } else if (vars.size() == 2) {
3446 return MakeEquality(vars[0], MakeDifference(cst, vars[1]));
3447 }
3448 if (DetectSumOverflow(vars)) {
3449 return RevAlloc(new SafeSumConstraint(this, vars, MakeIntConst(cst)));
3450 } else if (size <= parameters_.array_split_size()) {
3451 return RevAlloc(new SmallSumConstraint(this, vars, MakeIntConst(cst)));
3452 } else {
3453 return RevAlloc(new SumConstraint(this, vars, MakeIntConst(cst)));
3454 }
3455 }
3456}
3457
3458Constraint* Solver::MakeSumEquality(const std::vector<IntVar*>& vars,
3459 IntVar* const var) {
3460 const int size = vars.size();
3461 if (size == 0) {
3462 return MakeEquality(var, Zero());
3463 }
3464 if (AreAllBooleans(vars) && size > 2) {
3465 return RevAlloc(new SumBooleanEqualToVar(this, vars, var));
3466 } else if (size == 0) {
3467 return MakeEquality(var, Zero());
3468 } else if (size == 1) {
3469 return MakeEquality(vars[0], var);
3470 } else if (size == 2) {
3471 return MakeEquality(MakeSum(vars[0], vars[1]), var);
3472 } else {
3473 if (DetectSumOverflow(vars)) {
3474 return RevAlloc(new SafeSumConstraint(this, vars, var));
3475 } else if (size <= parameters_.array_split_size()) {
3476 return RevAlloc(new SmallSumConstraint(this, vars, var));
3477 } else {
3478 return RevAlloc(new SumConstraint(this, vars, var));
3479 }
3480 }
3481}
3482
3483Constraint* Solver::MakeScalProdEquality(const std::vector<IntVar*>& vars,
3484 const std::vector<int64>& coefficients,
3485 int64 cst) {
3486 DCHECK_EQ(vars.size(), coefficients.size());
3487 return MakeScalProdEqualityFct(this, vars, coefficients, cst);
3488}
3489
3490Constraint* Solver::MakeScalProdEquality(const std::vector<IntVar*>& vars,
3491 const std::vector<int>& coefficients,
3492 int64 cst) {
3493 DCHECK_EQ(vars.size(), coefficients.size());
3494 return MakeScalProdEqualityFct(this, vars, ToInt64Vector(coefficients), cst);
3495}
3496
3497Constraint* Solver::MakeScalProdEquality(const std::vector<IntVar*>& vars,
3498 const std::vector<int64>& coefficients,
3499 IntVar* const target) {
3500 DCHECK_EQ(vars.size(), coefficients.size());
3501 return MakeScalProdEqualityVarFct(this, vars, coefficients, target);
3502}
3503
3504Constraint* Solver::MakeScalProdEquality(const std::vector<IntVar*>& vars,
3505 const std::vector<int>& coefficients,
3506 IntVar* const target) {
3507 DCHECK_EQ(vars.size(), coefficients.size());
3508 return MakeScalProdEqualityVarFct(this, vars, ToInt64Vector(coefficients),
3509 target);
3510}
3511
3512Constraint* Solver::MakeScalProdGreaterOrEqual(const std::vector<IntVar*>& vars,
3513 const std::vector<int64>& coeffs,
3514 int64 cst) {
3515 DCHECK_EQ(vars.size(), coeffs.size());
3516 return MakeScalProdGreaterOrEqualFct(this, vars, coeffs, cst);
3517}
3518
3519Constraint* Solver::MakeScalProdGreaterOrEqual(const std::vector<IntVar*>& vars,
3520 const std::vector<int>& coeffs,
3521 int64 cst) {
3522 DCHECK_EQ(vars.size(), coeffs.size());
3523 return MakeScalProdGreaterOrEqualFct(this, vars, ToInt64Vector(coeffs), cst);
3524}
3525
3526Constraint* Solver::MakeScalProdLessOrEqual(
3527 const std::vector<IntVar*>& vars, const std::vector<int64>& coefficients,
3528 int64 cst) {
3529 DCHECK_EQ(vars.size(), coefficients.size());
3530 return MakeScalProdLessOrEqualFct(this, vars, coefficients, cst);
3531}
3532
3533Constraint* Solver::MakeScalProdLessOrEqual(
3534 const std::vector<IntVar*>& vars, const std::vector<int>& coefficients,
3535 int64 cst) {
3536 DCHECK_EQ(vars.size(), coefficients.size());
3537 return MakeScalProdLessOrEqualFct(this, vars, ToInt64Vector(coefficients),
3538 cst);
3539}
3540
3541IntExpr* Solver::MakeScalProd(const std::vector<IntVar*>& vars,
3542 const std::vector<int64>& coefs) {
3543 DCHECK_EQ(vars.size(), coefs.size());
3544 return MakeScalProdFct(this, vars, coefs);
3545}
3546
3547IntExpr* Solver::MakeScalProd(const std::vector<IntVar*>& vars,
3548 const std::vector<int>& coefs) {
3549 DCHECK_EQ(vars.size(), coefs.size());
3550 return MakeScalProdFct(this, vars, ToInt64Vector(coefs));
3551}
3552} // namespace operations_research
int64 min
Definition: alldiff_cst.cc:138
int64 max
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:495
#define DCHECK_NE(val1, val2)
Definition: base/logging.h:886
#define CHECK_LT(val1, val2)
Definition: base/logging.h:700
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:697
#define CHECK_GT(val1, val2)
Definition: base/logging.h:702
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:889
#define CHECK_NE(val1, val2)
Definition: base/logging.h:698
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:890
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:888
#define LOG(severity)
Definition: base/logging.h:420
#define DCHECK(condition)
Definition: base/logging.h:884
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:885
A constraint is the main modeling object.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual IntVar * Var()=0
Creates a variable from the expression.
The class IntVar is a subset of IntExpr.
void SetValue(Solver *const s, const T &val)
SatParameters parameters
const std::string name
const Constraint * ct
int64 value
Rev< int64 > node_max
Definition: expr_array.cc:141
const std::vector< IntVar * > vars_
Definition: expr_array.cc:135
IntVar * var
Definition: expr_array.cc:1858
#define IS_TYPE(type, tag)
Definition: expr_array.cc:2381
Rev< int64 > node_min
Definition: expr_array.cc:140
int64 coef
Definition: expr_array.cc:1859
RevSwitch inactive_
Definition: expr_array.cc:1471
static const int64 kint64max
int64_t int64
static const int64 kint64min
const int WARNING
Definition: log_severity.h:31
const int FATAL
Definition: log_severity.h:32
The vehicle routing library lets one model and solve generic vehicle routing problems ranging from th...
int64 CapAdd(int64 x, int64 y)
std::string JoinNamePtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:52
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
bool AreAllNegative(const std::vector< T > &values)
int64 CapProd(int64 x, int64 y)
bool AreAllPositive(const std::vector< T > &values)
int64 CapSub(int64 x, int64 y)
bool AreAllBooleans(const std::vector< IntVar * > &vars)
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::string JoinDebugStringPtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:45
std::vector< int64 > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:822
bool AreAllNull(const std::vector< T > &values)
bool AreAllOnes(const std::vector< T > &values)
bool AreAllBoundOrNull(const std::vector< IntVar * > &vars, const std::vector< T > &values)
Returns true if all the variables are assigned to a single value, or if their corresponding value is ...
int index
Definition: pack.cc:508
int64 coefficient
std::vector< double > coefficients
IntervalVar *const target_var_