/** Extend `std.algorithm.setops` to also operate on set- and map-like
    containers/ranges.
3 
4     See_Also: http://forum.dlang.org/post/nvd09v$24e9$1@digitalmars.com
5 */
6 module nxt.setops_ex;
7 
8 // version = show;
9 
/** Union (by key) of the associative arrays `a` and `b`, counterpart of
    `std.algorithm.setops.setUnion` for AAs.

    The argument with fewer entries is handed to `setUnionHelper` as its
    `small` parameter (`b` when both have the same length) and the other one as
    `large`; the helper builds the result from a copy of `large` and decides
    which value is kept for keys present in both.
*/
auto setUnionUpdate(T1, T2)(T1 a, T2 b) @trusted
if (isAA!T1 && isAA!T2)
{
    return (a.length < b.length) ? setUnionHelper(a, b)
                                 : setUnionHelper(b, a);
}
24 
/** Helper for `setUnionUpdate`: returns a new associative array holding the
    union of the keys of `small` and `large`.

    Assumes `small` has no more entries than `large`, so that the larger one is
    duplicated and only the smaller one is iterated. For keys present in both,
    the value from `small` is stored in the result. Neither argument is
    modified.
*/
private auto setUnionHelper(Small, Large)(const Small small, Large large)
{
    Large united = large.dup;   // TODO this shallow copy prevents large from being `const`
    foreach (const ref e; small.byKeyValue)
    {
        // Write into the copy only: `large` refers to the caller's AA, so
        // writing through it would silently mutate the caller's argument.
        united[e.key] = e.value;
    }
    return united;
}
44 
/** Is `true` iff `Set` is a set-like container, that is, it provides
 * membership checking via a `contains` member.
 *
 * TODO extend to also detect membership checking via the `in` operator
 * TODO Move to Phobos std.traits
 */
template hasContains(Set)
{
    import std.traits : hasMember;
    // Named like the template (eponymous) so that `hasContains!Set` itself
    // evaluates to this boolean.
    enum hasContains = hasMember!(Set, "contains");
}
55 
/** Is `true` iff `Map` is a built-in associative array type, as decided by
 * `std.traits.isAssociativeArray`. User-defined map-like containers are not
 * (yet) recognized.
 *
 * TODO Move to Phobos std.traits
 */
template isAA(Map)
{
    static import std.traits;
    // TODO check if in operator returns reference to value
    enum bool isAA = std.traits.isAssociativeArray!Map;
}
66 
67 version(unittest)
68 {
69     import std.algorithm.comparison : equal;
70     import nxt.dbgio : dbg;
71 }
72 
/// union of associative arrays (via keys)
@safe pure unittest
{
    alias Map = string[int];

    Map lhs = [0 : "a", 1 : "b"];
    Map rhs = [2 : "c"];
    Map expected = [0 : "a", 1 : "b", 2 : "c"];

    // the result must not depend on argument order (commutativity)
    assert(setUnionUpdate(lhs, rhs) == expected);
    assert(setUnionUpdate(rhs, lhs) == expected);
}
87 
88 import std.traits : CommonType;
89 import std.range.primitives;
90 import std.meta : allSatisfy, staticMap;
91 import std.functional : binaryFun;
92 import std.range : SearchPolicy;
93 import nxt.range_ex : haveCommonElementType;
94 
/** Intersection of two or more ranges of type `Rs`, each sorted by `less`.
 *
 * After construction and after each `popFront`, every input is positioned on
 * an element that is equivalent (under `less`) across all inputs, or some
 * input is empty. Inputs that are random-access with slicing and length skip
 * smaller elements via a `SearchPolicy`-driven search (`preferredSearchPolicy`,
 * by default `SearchPolicy.gallop`); other inputs skip them linearly.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
private struct SetIntersectionFast(alias less = "a < b",
                                   SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                                   Rs...)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
private:
    Rs _inputs;
    alias comp = binaryFun!less;
    alias ElementType = CommonType!(staticMap!(.ElementType, Rs));

    // Positions to the first elements that are all equivalent
    void adjustPosition()
    {
        if (empty) return;

        // Number of further (r, next) comparisons, cycling over the inputs,
        // that must pass without a reset before all fronts are known to be
        // equivalent.
        auto compsLeft = Rs.length;
        static if (Rs.length > 1) while (true)
        {
            foreach (i, ref r; _inputs)
            {
                alias next = _inputs[(i + 1) % Rs.length];
                alias Next = typeof(next);

                // TODO Use a search only when next.length / r.length > 12
                static if (isRandomAccessRange!Next &&
                           hasSlicing!Next &&
                           hasLength!Next)
                {
                    import std.range : assumeSorted;

                    // Drop the prefix of `next` whose elements are strictly
                    // less than `r.front`. Works for any `less`, not just the
                    // string predicates "a < b" and "a > b".
                    immutable skipCount = next.assumeSorted!less
                                              .lowerBound!preferredSearchPolicy(r.front)
                                              .length;
                    next = next[skipCount .. $];

                    if (next.empty)
                    {
                        return; // next became empty, so everything becomes empty
                    }
                    // `next.front` is now not less than `r.front`; compare with
                    // `less` (not `!=`) so equivalence agrees with the ordering.
                    if (comp(r.front, next.front))
                    {
                        compsLeft = Rs.length; // start counting again from the larger next.front
                    }
                }
                else
                {
                    if (comp(next.front, r.front))
                    {
                        do
                        {
                            next.popFront();
                            if (next.empty) return;
                        }
                        while (comp(next.front, r.front));
                        compsLeft = Rs.length;
                    }
                }
                if (--compsLeft == 0) return; // Rs.length comparisons in a row without reset: fronts agree
            }
        }
    }

public:
    ///
    this(Rs inputs)
    {
        import std.functional : forward;
        this._inputs = forward!inputs; // TODO remove `forward` when compiler does it for us
        // position to the first element
        adjustPosition();
    }

    ///
    @property bool empty()
    {
        foreach (ref r; _inputs)
        {
            if (r.empty) return true;
        }
        return false;
    }

    ///
    void popFront()
    {
        assert(!empty);
        static if (Rs.length > 1) foreach (i, ref r; _inputs)
        {
            alias next = _inputs[(i + 1) % Rs.length];
            assert(!comp(r.front, next.front));
        }

        foreach (ref r; _inputs)
        {
            r.popFront();
        }
        adjustPosition();
    }

    ///
    @property ElementType front()
    {
        assert(!empty);
        return _inputs[0].front;
    }

    static if (allSatisfy!(isForwardRange, Rs))
    {
        ///
        @property SetIntersectionFast save()
        {
            auto ret = this;
            foreach (i, ref r; _inputs)
            {
                ret._inputs[i] = r.save;
            }
            return ret;
        }
    }
}
229 
230 import core.internal.traits : Unqual;
231 
/** Wraps `r` in a `MoveableSortedRange` ordered by `pred`, moving `r` into the
    wrapper instead of copying it. No sorting is performed: the caller
    guarantees that `r` is already sorted by `pred`.
*/
auto assumeMoveableSorted(alias pred = "a < b", R)(R r)
if (isInputRange!(Unqual!R))
{
    import core.lifetime : move;
    alias Wrapped = MoveableSortedRange!(Unqual!R, pred);
    return Wrapped(move(r)); // TODO remove `move` when compiler does it for us
}
238 
/** Get intersection of `ranges`, each of which must be sorted by `less`.
 *
 * The result is wrapped in a `MoveableSortedRange` ordered by the same `less`,
 * so sorted-range searches on the result use the inputs' ordering.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
MoveableSortedRange!(SetIntersectionFast!(less, preferredSearchPolicy, Rs), less)
setIntersectionFast(alias less = "a < b",
                    SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                    Rs...)(Rs ranges)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
    alias Intersection = SetIntersectionFast!(less, preferredSearchPolicy, Rs);

    // TODO Remove need for these switch cases if this can be fixed:
    // http://forum.dlang.org/post/pknonazfniihvpicxbld@forum.dlang.org
    static if (Rs.length == 2)
    {
        import core.lifetime : move;
        return assumeMoveableSorted!less(Intersection(move(ranges[0]), // TODO remove `move` when compiler does it for us
                                                      move(ranges[1]))); // TODO remove `move` when compiler does it for us
    }
    else
    {
        import std.functional : forward;
        return assumeMoveableSorted!less(Intersection(forward!ranges)); // TODO remove `forward` when compiler does it for us
    }
}
270 
@safe unittest
{
    import std.algorithm.sorting : sort;
    import std.algorithm.setops : setIntersection;
    import nxt.random_ex : randInPlaceWithElementRange;
    import nxt.dynamic_array : DynamicArray;

    alias E = ulong;
    alias A = DynamicArray!E;

    auto a0 = A();
    auto a1 = A(1);

    enum less = "a < b";

    // intersecting a range with itself yields that range
    auto s0 = setIntersectionFast!(less)(a0[], a0[]);
    assert(s0.equal(a0[]));

    auto s1 = setIntersectionFast!(less)(a1[], a1[]);
    assert(s1.equal(a1[]));

    immutable smallTestLength = 1000;
    immutable factor = 12; // this is the magical limit on my laptop when performance of `upperBound` beats standard implementation
    immutable largeTestLength = factor*smallTestLength;
    E elementLow = 0;
    E elementHigh = 10_000_000;
    auto x = A.withLength(smallTestLength);
    auto y = A.withLength(largeTestLength);

    x[].randInPlaceWithElementRange(elementLow, elementHigh);
    y[].randInPlaceWithElementRange(elementLow, elementHigh);

    sort(x[]);
    sort(y[]);

    // commutative
    assert(equal(setIntersectionFast!(less)(x[], y[]),
                 setIntersectionFast!(less)(y[], x[])));

    // same result as Phobos' setIntersection
    assert(equal(setIntersection!(less)(x[], y[]),
                 setIntersectionFast!(less)(x[], y[])));

    // The benchmark is only useful when its results are printed, so only run
    // it then instead of slowing down every unittest build.
    version(show)
    {
        import nxt.algorithm_ex : collect;
        import std.datetime.stopwatch : benchmark;
        import core.time : Duration;
        import std.stdio : writeln;
        import std.conv : to;

        void testSetIntersection()
        {
            auto z = setIntersection!(less)(x[], y[]).collect!A;
        }

        void testSetIntersectionNew()
        {
            auto z = setIntersectionFast!(less)(x[], y[]).collect!A;
        }

        immutable testCount = 10;
        auto r = benchmark!(testSetIntersection,
                            testSetIntersectionNew)(testCount);

        writeln("old testSetIntersection: ", to!Duration(r[0]));
        writeln("new testSetIntersection: ", to!Duration(r[1]));
    }
}
339 
@safe pure nothrow unittest
{
    enum less = "a < b";
    auto intersection = setIntersectionFast!(less)([1, 2, 3],
                                                   [1, 2, 3]);
    const saved = intersection.save(); // `save` must be callable on the result
    assert(intersection.equal([1, 2, 3]));
}
348 
/** Sorted-range wrapper modelled on Phobos' `std.range.SortedRange` that takes
 * ownership of its underlying range by moving it instead of copying it.
 *
 * Construct it via `assumeMoveableSorted`. In `debug` builds (outside CTFE),
 * random-access inputs with a length are spot-checked for sortedness on
 * construction.
 *
 * TODO remove this `MoveableSortedRange` and replace with Phobos' `SortedRange` in this buffer
 */
struct MoveableSortedRange(Range, alias pred = "a < b")
if (isInputRange!Range)
{
    import std.functional : binaryFun;

    private alias predFun = binaryFun!pred;

    // Local equivalent of the helper of the same name used by Phobos'
    // `SortedRange` (which this module does not import): true iff `fn` can be
    // called both as `fn(T1, T2)` and as `fn(T2, T1)`.
    private template isTwoWayCompatible(alias fn, T1, T2)
    {
        enum isTwoWayCompatible = is(typeof(() {
            T1 foo();
            T2 bar();
            cast(void) fn(foo(), bar());
            cast(void) fn(bar(), foo());
        }));
    }

    private bool geq(L, R)(L lhs, R rhs)
    {
        return !predFun(lhs, rhs);
    }
    private bool gt(L, R)(L lhs, R rhs)
    {
        return predFun(rhs, lhs);
    }
    private Range _input;

    // Undocumented because a clearer way to invoke is by calling
    // `assumeMoveableSorted`. Takes ownership of `input` by moving it.
    this(Range input)
    out
    {
        // moved out of the body as a workaround for Issue 12661
        dbgVerifySorted();
    }
    do
    {
        import core.lifetime : move;
        this._input = move(input); // TODO remove `move` when compiler does it for us
    }

    // Assertion only: in `debug` builds (outside CTFE) checks that about
    // log2(length) evenly strided elements of a random-access input are sorted.
    private void dbgVerifySorted()
    {
        if (!__ctfe)
        debug
        {
            static if (isRandomAccessRange!Range && hasLength!Range)
            {
                import core.bitop : bsr;
                import std.algorithm.sorting : isSorted;
                import std.range : stride; // not imported at module level

                // Check the sortedness of the input
                if (this._input.length < 2) return;

                immutable size_t msb = bsr(this._input.length) + 1;
                assert(msb > 0 && msb <= this._input.length);
                immutable step = this._input.length / msb;
                auto st = stride(this._input, step);

                assert(isSorted!pred(st), "Range is not sorted");
            }
        }
    }

    /// Range primitives.
    @property bool empty()             //const
    {
        return this._input.empty;
    }

    /// Ditto
    static if (isForwardRange!Range)
    @property auto save()
    {
        // Avoid the constructor
        typeof(this) result = this;
        result._input = _input.save;
        return result;
    }

    /// Ditto
    @property auto ref front()
    {
        return _input.front;
    }

    /// Ditto
    void popFront()
    {
        _input.popFront();
    }

    /// Ditto
    static if (isBidirectionalRange!Range)
    {
        @property auto ref back()
        {
            return _input.back;
        }

        /// Ditto
        void popBack()
        {
            _input.popBack();
        }
    }

    /// Ditto
    static if (isRandomAccessRange!Range)
        auto ref opIndex(size_t i)
        {
            return _input[i];
        }

    /// Ditto
    static if (hasSlicing!Range)
        auto opSlice(size_t a, size_t b)
        {
            assert(
                a <= b,
                "Attempting to slice a SortedRange with a larger first argument than the second."
            );
            typeof(this) result = this;
            result._input = _input[a .. b];// skip checking
            return result;
        }

    /// Ditto
    static if (hasLength!Range)
    {
        @property size_t length()          //const
        {
            return _input.length;
        }
        alias opDollar = length;
    }

/**
   Releases the controlled range and returns it (moved out of this wrapper).
*/
    auto release()
    {
        import core.lifetime : move;
        return move(_input);
    }

    // Assuming a predicate "test" that returns 0 for a left portion
    // of the range and then 1 for the rest, returns the index at
    // which the first 1 appears. Used internally by the search routines.
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if (sp == SearchPolicy.binarySearch && isRandomAccessRange!Range && hasLength!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2, it = first + step;
            if (!test(_input[it], v))
            {
                first = it + 1;
                count -= step + 1;
            }
            else
            {
                count = step;
            }
        }
        return first;
    }

    // Specialization for trot and gallop
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trot || sp == SearchPolicy.gallop)
        && isRandomAccessRange!Range)
    {
        if (empty || test(front, v)) return 0;
        immutable count = length;
        if (count == 1) return 1;
        size_t below = 0, above = 1, step = 2;
        while (!test(_input[above], v))
        {
            // Still too small, update below and increase gait
            below = above;
            immutable next = above + step;
            if (next >= count)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply exit the loop to do the
                // binary search thingie.
                above = count;
                break;
            }
            // Still in business, increase step and continue
            above = next;
            static if (sp == SearchPolicy.trot)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

    // Specialization for trotBackwards and gallopBackwards
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trotBackwards || sp == SearchPolicy.gallopBackwards)
        && isRandomAccessRange!Range)
    {
        immutable count = length;
        if (empty || !test(back, v)) return count;
        if (count == 1) return 0;
        size_t below = count - 2, above = count - 1, step = 2;
        while (test(_input[below], v))
        {
            // Still too large, update above and increase gait
            above = below;
            if (below < step)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply fall through to do the
                // binary search thingie.
                below = 0;
                break;
            }
            // Still in business, increase step and continue
            below -= step;
            // In this backwards specialization `sp` is never `trot`, so the
            // linear gait must be selected by `trotBackwards`.
            static if (sp == SearchPolicy.trotBackwards)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

// lowerBound
/**
   This function uses a search with policy $(D sp) to find the
   largest left subrange on which $(D pred(x, value)) is `true` for
   all $(D x) (e.g., if $(D pred) is "less than", returns the portion of
   the range with elements strictly smaller than `value`). The search
   schedule and its complexity are documented in
   $(LREF SearchPolicy).  See also STL's
   $(HTTP sgi.com/tech/stl/lower_bound.html, lower_bound).
*/
    auto lowerBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
         && hasSlicing!Range)
    {
        return this[0 .. getTransitionIndex!(sp, geq)(value)];
    }

// upperBound
/**
This function searches with policy $(D sp) to find the largest right
subrange on which $(D pred(value, x)) is `true` for all $(D x)
(e.g., if $(D pred) is "less than", returns the portion of the range
with elements strictly greater than `value`). The search schedule
and its complexity are documented in $(LREF SearchPolicy).

For ranges that do not offer random access, $(D SearchPolicy.linear)
is the only policy allowed (and it must be specified explicitly lest it exposes
user code to unexpected inefficiencies). For random-access searches, all
policies are allowed, and $(D SearchPolicy.binarySearch) is the default.

See_Also: STL's $(HTTP sgi.com/tech/stl/upper_bound.html, upper_bound).
*/
    auto upperBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V))
    {
        static assert(hasSlicing!Range || sp == SearchPolicy.linear,
            "Specify SearchPolicy.linear explicitly for "
            ~ typeof(this).stringof);
        static if (sp == SearchPolicy.linear)
        {
            for (; !_input.empty && !predFun(value, _input.front);
                 _input.popFront())
            {
            }
            return this;
        }
        else
        {
            return this[getTransitionIndex!(sp, gt)(value) .. length];
        }
    }


// equalRange
/**
   Returns the subrange containing all elements $(D e) for which both $(D
   pred(e, value)) and $(D pred(value, e)) evaluate to $(D false) (e.g.,
   if $(D pred) is "less than", returns the portion of the range with
   elements equal to `value`). Uses a classic binary search with
   interval halving until it finds a value that satisfies the condition,
   then uses $(D SearchPolicy.gallopBackwards) to find the left boundary
   and $(D SearchPolicy.gallop) to find the right boundary. These
   policies are justified by the fact that the two boundaries are likely
   to be near the first found value (i.e., equal ranges are relatively
   small). Completes the entire search in $(BIGOH log(n)) time. See also
   STL's $(HTTP sgi.com/tech/stl/equal_range.html, equal_range).
*/
    auto equalRange(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return this[left .. right];
            }
        }
        return this.init;
    }

// trisect
/**
Returns a tuple $(D r) such that $(D r[0]) is the same as the result
of $(D lowerBound(value)), $(D r[1]) is the same as the result of $(D
equalRange(value)), and $(D r[2]) is the same as the result of $(D
upperBound(value)). The call is faster than computing all three
separately. Uses a search schedule similar to $(D
equalRange). Completes the entire search in $(BIGOH log(n)) time.
*/
    auto trisect(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range && hasLength!Range)
    {
        import std.typecons : tuple;
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return tuple(this[0 .. left], this[left .. right],
                        this[right .. length]);
            }
        }
        // No equal element was found
        return tuple(this[0 .. first], this.init, this[first .. length]);
    }

// contains
/**
Returns `true` if and only if `value` can be found in $(D
range), which is assumed to be sorted. Performs $(BIGOH log(r.length))
evaluations of $(D pred). See also STL's $(HTTP
sgi.com/tech/stl/binary_search.html, binary_search).
 */

    bool contains(V)(V value)
    if (isRandomAccessRange!Range)
    {
        if (empty) return false;
        immutable i = getTransitionIndex!(SearchPolicy.binarySearch, geq)(value);
        if (i >= length) return false;
        return !predFun(value, _input[i]);
    }

// groupBy
/**
Returns a range of subranges of elements that are equivalent according to the
sorting relation.
 */
    auto groupBy()()
    {
        import std.algorithm.iteration : chunkBy;
        return _input.chunkBy!((a, b) => !predFun(a, b) && !predFun(b, a));
    }
}