/** Extend std.algorithm.setops to also operate on set- and map-like
    containers/ranges.

    See_Also: http://forum.dlang.org/post/nvd09v$24e9$1@digitalmars.com
*/
6 module nxt.setops_ex;
7 
8 // version = show;
9 
/** Union (by key) of the associative arrays `a` and `b`; the
    associative-array counterpart of `std.algorithm.setops.setUnion`.

    Whichever argument is shorter is merged into a copy of the other one
    via `setUnionHelper`.
*/
auto setUnionUpdate(T1, T2)(T1 a, T2 b) @trusted
if (isAA!T1 && isAA!T2)
{
    immutable aIsShorter = a.length < b.length;
    if (aIsShorter)
        return setUnionHelper(a, b); // iterate `a`, copy `b`
    return setUnionHelper(b, a);     // iterate `b`, copy `a`
}
24 
/** Helper for `setUnionUpdate`: union (by key) of `small` and `large`.

    Intended to be called with `small.length <= large.length`, so that only
    the smaller map is iterated while the larger one is copied with `.dup`.

    Returns: a new associative array of type `Large` containing every key of
    both arguments. For a key present in both, the value from `small` is
    stored. Neither argument is modified.
*/
private static auto setUnionHelper(Small, Large)(const Small small, Large large)
{
    Large united = large.dup; // TODO: this shallow copy prevents `large` from being `const`
    foreach (e; small.byKeyValue)
    {
        // Always write into the copy. Writing through a pointer into `large`
        // would mutate the caller's map (AAs are reference types) and would
        // not be reflected in `united`.
        united[e.key] = e.value;
    }
    return united;
}
44 
/** Is `true` iff `Map` is a built-in associative array type.
 *
 * NOTE: user-defined map-like containers that merely provide `in` or
 * `contains` are not (yet) accepted.
 *
 * TODO: Move to Phobos std.traits
 * TODO: check if the `in` operator returns a reference to the value
 */
template isAA(Map)
{
    static import std.traits;
    enum bool isAA = std.traits.isAssociativeArray!Map;
}
55 
/// Union of associative arrays (by key).
@safe pure unittest
{
    alias Map = string[int];

    Map lhs = [0 : "a", 1 : "b"];
    Map rhs = [2 : "c"];
    Map expected = [0 : "a", 1 : "b", 2 : "c"];

    // The result must not depend on argument order (commutativity).
    assert(setUnionUpdate(lhs, rhs) == expected);
    assert(setUnionUpdate(rhs, lhs) == expected);
}
70 
71 import std.traits : CommonType;
72 import std.range.primitives;
73 import std.meta : allSatisfy, staticMap;
74 import std.functional : binaryFun;
75 import std.range : SearchPolicy;
76 import nxt.range_ex : haveCommonElementType;
77 
/** Intersection of two or more ranges of type `Rs`, all sorted by `less`.
 *
 * When the successor of an input is random-access, sliceable and has a
 * length, it is advanced using an `upperBound` search with
 * `preferredSearchPolicy` instead of element by element.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
private struct SetIntersectionFast(alias less = "a < b",
                                   SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                                   Rs...)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
private:
    Rs _inputs;
    alias comp = binaryFun!less;
    alias ElementType = CommonType!(staticMap!(.ElementType, Rs));

    // Advances the inputs until all fronts are equivalent under `comp`, or
    // until some input becomes empty.
    void adjustPosition()
    {
        import std.range.primitives : hasLength, hasSlicing, isRandomAccessRange;

        if (empty) return;

        // Non-strict counterpart of `less` ("`a` does not come after `b`"),
        // used with `upperBound` so that elements equivalent to the searched
        // value are kept. The two common string predicates keep their string
        // counterparts; any other predicate gets a generic lambda.
        // TODO: remove need for this hack
        static if (is(typeof(less) : string))
        {
            static if (less == "a < b")
                enum lessEq = "a <= b";
            else static if (less == "a > b")
                enum lessEq = "a >= b";
            else
                alias lessEq = (a, b) => !comp(b, a);
        }
        else
        {
            alias lessEq = (a, b) => !comp(b, a);
        }

        // Number of consecutive pairwise checks (without a restart) still
        // needed before all fronts are known to be equivalent.
        auto compsLeft = Rs.length;
        static if (Rs.length > 1) while (true)
        {
            foreach (i, ref r; _inputs)
            {
                // `next` aliases (does not copy) the cyclic successor of `r`.
                alias next = _inputs[(i + 1) % Rs.length];
                alias Next = typeof(next);

                // TODO: Use upperBound only when next.length / r.length > 12

                static if (isRandomAccessRange!Next &&
                           hasSlicing!Next &&
                           hasLength!Next)
                {
                    import std.range : assumeSorted;

                    // Drop the elements of `next` that come strictly before `r.front`.
                    // TODO: can we merge these two lines into one single assignment from nextUpperBound to next
                    auto nextUpperBound = next.assumeSorted!lessEq.upperBound!preferredSearchPolicy(r.front);
                    next = next[$ - nextUpperBound.length .. $];

                    if (next.empty)
                    {
                        return; // next became empty, so everything becomes empty
                    }
                    else if (comp(r.front, next.front))
                    {
                        compsLeft = Rs.length; // new candidate `next.front`: start counting again
                    }
                }
                else
                {
                    if (comp(next.front, r.front))
                    {
                        do
                        {
                            next.popFront();
                            if (next.empty) return;
                        }
                        while (comp(next.front, r.front));
                        compsLeft = Rs.length;
                    }
                }
                // Rs.length consecutive checks without a restart cover every
                // cyclic pair, so all fronts are equivalent.
                if (--compsLeft == 0) return;
            }
        }
    }

public:
    /// Construct from `inputs` and position on the first common element.
    this(Rs inputs)
    {
        import std.functional : forward;
        this._inputs = forward!inputs; // TODO: remove `forward` when compiler does it for us
        // position to the first element
        adjustPosition();
    }

    /// `true` iff any input is empty.
    @property bool empty()
    {
        foreach (ref r; _inputs)
        {
            if (r.empty) return true;
        }
        return false;
    }

    /// Pop the current common element from every input and re-synchronize.
    void popFront()
    {
        assert(!empty);
        static if (Rs.length > 1) foreach (i, ref r; _inputs)
        {
            alias next = _inputs[(i + 1) % Rs.length];
            assert(!comp(r.front, next.front));
        }

        foreach (ref r; _inputs)
        {
            r.popFront();
        }
        adjustPosition();
    }

    /// Current common element (taken from the first input).
    @property ElementType front()
    {
        assert(!empty);
        return _inputs[0].front;
    }

    static if (allSatisfy!(isForwardRange, Rs))
    {
        /// Copy whose inputs are independently saved.
        @property SetIntersectionFast save()
        {
            auto ret = this;
            foreach (i, ref r; _inputs)
            {
                ret._inputs[i] = r.save;
            }
            return ret;
        }
    }
}
212 
213 import core.internal.traits : Unqual;
214 
/** Wrap `r` in a `MoveableSortedRange` ordered by `pred`, moving `r` into
    the wrapper instead of copying it. Sortedness is assumed, not established.
*/
auto assumeMoveableSorted(alias pred = "a < b", R)(R r)
if (isInputRange!(Unqual!R))
{
    import core.lifetime : move;
    alias Bare = Unqual!R;
    // TODO: remove `move` when compiler does it for us
    return MoveableSortedRange!(Bare, pred)(move(r));
}
221 
/** Get intersection of `ranges`, all assumed sorted by `less`.
 *
 * The result is a `MoveableSortedRange` tagged with the same predicate
 * `less`, so that its sortedness is described correctly even for
 * non-default orderings such as `"a > b"`.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
MoveableSortedRange!(SetIntersectionFast!(less, preferredSearchPolicy, Rs), less)
setIntersectionFast(alias less = "a < b",
                    SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                    Rs...)(Rs ranges)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
    alias Intersection = SetIntersectionFast!(less, preferredSearchPolicy, Rs);

    // TODO: Remove need for these switch cases if this can be fixed:
    // http://forum.dlang.org/post/pknonazfniihvpicxbld@forum.dlang.org
    static if (Rs.length == 2)
    {
        import core.lifetime : move;
        return assumeMoveableSorted!less(Intersection(move(ranges[0]), // TODO: remove `move` when compiler does it for us
                                                      move(ranges[1]))); // TODO: remove `move` when compiler does it for us
    }
    else
    {
        import std.functional : forward;
        return assumeMoveableSorted!less(Intersection(forward!ranges)); // TODO: remove `forward` when compiler does it for us
    }
}
253 
@safe unittest
{
    import std.algorithm.comparison : equal;
    import std.algorithm.sorting : sort;
    import std.algorithm.setops : setIntersection;
    import nxt.random_ex : randInPlaceWithElementRange;
    import nxt.dynamic_array : DynamicArray;

    alias E = ulong;
    alias A = DynamicArray!E;

    enum less = "a < b";

    // Degenerate inputs.
    auto a0 = A();
    auto a1 = A(1);

    auto s0 = setIntersectionFast!(less)(a0[], a0[]);
    assert(s0.equal(a0[]));

    auto s1 = setIntersectionFast!(less)(a1[], a1[]);
    assert(s1.equal(a1[]));

    immutable smallTestLength = 1000;
    immutable factor = 12; // this is the magical limit on my laptop when performance of `upperBound` beats standard implementation
    immutable largeTestLength = factor*smallTestLength;
    E elementLow = 0;
    E elementHigh = 10_000_000;
    auto x = A.withLength(smallTestLength);
    auto y = A.withLength(largeTestLength);

    x[].randInPlaceWithElementRange(elementLow, elementHigh);
    y[].randInPlaceWithElementRange(elementLow, elementHigh);

    sort(x[]);
    sort(y[]);

    // commutative
    assert(equal(setIntersectionFast!(less)(x[], y[]),
                 setIntersectionFast!(less)(y[], x[])));

    // same as Phobos' `setIntersection`
    assert(equal(setIntersection!(less)(x[], y[]),
                 setIntersectionFast!(less)(x[], y[])));

    // The benchmark is only useful when its result is printed, so it no
    // longer runs on every unittest build.
    version(show)
    {
        import nxt.algorithm_ex : collect;
        import std.datetime.stopwatch : benchmark;
        import core.time : Duration;
        import std.stdio : writeln;
        import std.conv : to;

        void testSetIntersection()
        {
            auto z = setIntersection!(less)(x[], y[]).collect!A;
        }

        void testSetIntersectionNew()
        {
            auto z = setIntersectionFast!(less)(x[], y[]).collect!A;
        }

        immutable testCount = 10;
        auto r = benchmark!(testSetIntersection,
                            testSetIntersectionNew)(testCount);

        writeln("old testSetIntersection: ", to!Duration(r[0]));
        writeln("new testSetIntersection: ", to!Duration(r[1]));
    }
}
323 
@safe pure nothrow unittest
{
    import std.algorithm.comparison : equal;
    enum less = "a < b";
    auto si = setIntersectionFast!(less)([1, 2, 3],
                                         [1, 2, 3]);
    // `save` must produce an independent copy: advancing `si` must not
    // affect the saved range.
    auto saved = si.save();
    si.popFront();
    assert(si.equal([2, 3]));
    assert(saved.equal([1, 2, 3]));
}
333 
// TODO: remove this `MoveableSortedRange` and replace with Phobos' `SortedRange` in this buffer
/** Sorted-range wrapper modelled on Phobos' `SortedRange`, which takes
    ownership of its input by moving it in (see `assumeMoveableSorted`).

    `pred` is the strict ordering the wrapped range is assumed to follow.
*/
struct MoveableSortedRange(Range, alias pred = "a < b")
if (isInputRange!Range)
{
    import std.functional : binaryFun;
    // `isTwoWayCompatible` is used by the search constraints below and
    // `stride` by `dbgVerifySorted`; neither is provided by the module imports.
    import std.range : isTwoWayCompatible, stride;

    private alias predFun = binaryFun!pred;
    private bool geq(L, R)(L lhs, R rhs)
    {
        return !predFun(lhs, rhs);
    }
    private bool gt(L, R)(L lhs, R rhs)
    {
        return predFun(rhs, lhs);
    }
    private Range _input;

    // Undocumented because a clearer way to invoke is by calling
    // assumeMoveableSorted.
    this(Range input)
    out
    {
        // moved out of the body as a workaround for Issue 12661
        dbgVerifySorted();
    }
    do
    {
        import core.lifetime : move;
        this._input = move(input); // TODO
    }

    // Assertion only: in debug builds, spot-checks sortedness on a stride
    // of roughly log2(length) elements.
    private void dbgVerifySorted()
    {
        if (!__ctfe)
        debug
        {
            static if (isRandomAccessRange!Range && hasLength!Range)
            {
                import core.bitop : bsr;
                import std.algorithm.sorting : isSorted;

                // Check the sortedness of the input
                if (this._input.length < 2) return;

                immutable size_t msb = bsr(this._input.length) + 1;
                assert(msb > 0 && msb <= this._input.length);
                immutable step = this._input.length / msb;
                auto st = stride(this._input, step);

                assert(isSorted!pred(st), "Range is not sorted");
            }
        }
    }

    /// Range primitives.
    @property bool empty()             //const
    {
        return this._input.empty;
    }

    /// Ditto
    static if (isForwardRange!Range)
    @property auto save()
    {
        // Avoid the constructor
        typeof(this) result = this;
        result._input = _input.save;
        return result;
    }

    /// Ditto
    @property auto ref front()
    {
        return _input.front;
    }

    /// Ditto
    void popFront()
    {
        _input.popFront();
    }

    /// Ditto
    static if (isBidirectionalRange!Range)
    {
        @property auto ref back()
        {
            return _input.back;
        }

        /// Ditto
        void popBack()
        {
            _input.popBack();
        }
    }

    /// Ditto
    static if (isRandomAccessRange!Range)
        auto ref opIndex(size_t i)
        {
            return _input[i];
        }

    /// Ditto
    static if (hasSlicing!Range)
        auto opSlice(size_t a, size_t b)
        {
            assert(
                a <= b,
                "Attempting to slice a SortedRange with a larger first argument than the second."
            );
            typeof(this) result = this;
            result._input = _input[a .. b];// skip checking
            return result;
        }

    /// Ditto
    static if (hasLength!Range)
    {
        @property size_t length()          //const
        {
            return _input.length;
        }
        alias opDollar = length;
    }

/**
   Releases the controlled range and returns it.
*/
    auto release()
    {
        import core.lifetime : move;
        return move(_input);
    }

    // Assuming a predicate "test" that returns 0 for a left portion
    // of the range and then 1 for the rest, returns the index at
    // which the first 1 appears. Used internally by the search routines.
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if (sp == SearchPolicy.binarySearch && isRandomAccessRange!Range && hasLength!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2, it = first + step;
            if (!test(_input[it], v))
            {
                first = it + 1;
                count -= step + 1;
            }
            else
            {
                count = step;
            }
        }
        return first;
    }

    // Specialization for trot and gallop
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trot || sp == SearchPolicy.gallop)
        && isRandomAccessRange!Range)
    {
        if (empty || test(front, v)) return 0;
        immutable count = length;
        if (count == 1) return 1;
        size_t below = 0, above = 1, step = 2;
        while (!test(_input[above], v))
        {
            // Still too small, update below and increase gait
            below = above;
            immutable next = above + step;
            if (next >= count)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply exit the loop to do the
                // binary search thingie.
                above = count;
                break;
            }
            // Still in business, increase step and continue
            above = next;
            static if (sp == SearchPolicy.trot)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

    // Specialization for trotBackwards and gallopBackwards
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trotBackwards || sp == SearchPolicy.gallopBackwards)
        && isRandomAccessRange!Range)
    {
        immutable count = length;
        if (empty || !test(back, v)) return count;
        if (count == 1) return 0;
        size_t below = count - 2, above = count - 1, step = 2;
        while (test(_input[below], v))
        {
            // Still too large, update above and increase gait
            above = below;
            if (below < step)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply fall through to do the
                // binary search thingie.
                below = 0;
                break;
            }
            // Still in business, increase step and continue
            below -= step;
            // `sp` is never `trot` in this overload; compare against
            // `trotBackwards` so that it actually trots instead of galloping.
            static if (sp == SearchPolicy.trotBackwards)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

// lowerBound
/**
   This function uses a search with policy $(D sp) to find the
   largest left subrange on which $(D pred(x, value)) is `true` for
   all $(D x) (e.g., if $(D pred) is "less than", returns the portion of
   the range with elements strictly smaller than `value`). The search
   schedule and its complexity are documented in
   $(LREF SearchPolicy).  See also STL's
   $(HTTP sgi.com/tech/stl/lower_bound.html, lower_bound).
*/
    auto lowerBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
         && hasSlicing!Range)
    {
        return this[0 .. getTransitionIndex!(sp, geq)(value)];
    }

// upperBound
/**
This function searches with policy $(D sp) to find the largest right
subrange on which $(D pred(value, x)) is `true` for all $(D x)
(e.g., if $(D pred) is "less than", returns the portion of the range
with elements strictly greater than `value`). The search schedule
and its complexity are documented in $(LREF SearchPolicy).

For ranges that do not offer random access, $(D SearchPolicy.linear)
is the only policy allowed (and it must be specified explicitly lest it exposes
user code to unexpected inefficiencies). For random-access searches, all
policies are allowed, and $(D SearchPolicy.binarySearch) is the default.

See_Also: STL's $(HTTP sgi.com/tech/stl/upper_bound.html, upper_bound).
*/
    auto upperBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V))
    {
        static assert(hasSlicing!Range || sp == SearchPolicy.linear,
            "Specify SearchPolicy.linear explicitly for "
            ~ typeof(this).stringof);
        static if (sp == SearchPolicy.linear)
        {
            for (; !_input.empty && !predFun(value, _input.front);
                 _input.popFront())
            {
            }
            return this;
        }
        else
        {
            return this[getTransitionIndex!(sp, gt)(value) .. length];
        }
    }


// equalRange
/**
   Returns the subrange containing all elements $(D e) for which both $(D
   pred(e, value)) and $(D pred(value, e)) evaluate to $(D false) (e.g.,
   if $(D pred) is "less than", returns the portion of the range with
   elements equal to `value`). Uses a classic binary search with
   interval halving until it finds a value that satisfies the condition,
   then uses $(D SearchPolicy.gallopBackwards) to find the left boundary
   and $(D SearchPolicy.gallop) to find the right boundary. These
   policies are justified by the fact that the two boundaries are likely
   to be near the first found value (i.e., equal ranges are relatively
   small). Completes the entire search in $(BIGOH log(n)) time. See also
   STL's $(HTTP sgi.com/tech/stl/equal_range.html, equal_range).
*/
    auto equalRange(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return this[left .. right];
            }
        }
        return this.init;
    }

// trisect
/**
Returns a tuple $(D r) such that $(D r[0]) is the same as the result
of $(D lowerBound(value)), $(D r[1]) is the same as the result of $(D
equalRange(value)), and $(D r[2]) is the same as the result of $(D
upperBound(value)). The call is faster than computing all three
separately. Uses a search schedule similar to $(D
equalRange). Completes the entire search in $(BIGOH log(n)) time.
*/
    auto trisect(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range && hasLength!Range)
    {
        import std.typecons : tuple;
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return tuple(this[0 .. left], this[left .. right],
                        this[right .. length]);
            }
        }
        // No equal element was found
        return tuple(this[0 .. first], this.init, this[first .. length]);
    }

// contains
/**
Returns `true` if and only if `value` can be found in $(D
range), which is assumed to be sorted. Performs $(BIGOH log(r.length))
evaluations of $(D pred). See also STL's $(HTTP
sgi.com/tech/stl/binary_search.html, binary_search).
 */

    bool contains(V)(V value)
    if (isRandomAccessRange!Range)
    {
        if (empty) return false;
        immutable i = getTransitionIndex!(SearchPolicy.binarySearch, geq)(value);
        if (i >= length) return false;
        return !predFun(value, _input[i]);
    }

// groupBy
/**
Returns a range of subranges of elements that are equivalent according to the
sorting relation.
 */
    auto groupBy()()
    {
        import std.algorithm.iteration : chunkBy;
        return _input.chunkBy!((a, b) => !predFun(a, b) && !predFun(b, a));
    }
}