/** Extend `std.algorithm.setops` to also operate on set- and map-like
    containers/ranges.
3 
4     See_Also: http://forum.dlang.org/post/nvd09v$24e9$1@digitalmars.com
5 */
6 module nxt.setops_ex;
7 
8 // version = show;
9 
/** Union, by key, of the built-in associative arrays `a` and `b`.

    Specialization of `std.algorithm.setops.setUnion` for associative arrays.
    The operand with fewer entries is merged into a copy of the other one via
    `setUnionHelper`. When both have the same length, `b` is merged into a
    copy of `a`.
*/
auto setUnionUpdate(T1, T2)(T1 a, T2 b) @trusted
if (isAA!T1 && isAA!T2)
{
    // Only the shorter operand gets iterated.
    if (a.length < b.length)
        return setUnionHelper(a, b);
    return setUnionHelper(b, a);
}
24 
/** Helper for `setUnionUpdate`: union of `small` and `large`, where `small`
    is expected to be the one with fewer entries.

    Params:
        small = entries merged into the result; only read.
        large = shallow-copied with `dup` to form the result; never modified.

    Returns: a new associative array of type `Large` containing every key of
    both inputs. For a key present in both, the value from `small` is used.
*/
private static auto setUnionHelper(Small, Large)(const Small small, Large large)
{
    Large united = large.dup; // TODO this shallow copy prevents `large` from being `const`
    foreach (const ref e; small.byKeyValue)
    {
        // Always write into the copy. Assigning through a pointer obtained from
        // `e.key in large` would mutate the caller's AA and leave `united`
        // holding the old value for that key.
        united[e.key] = e.value;
    }
    return united;
}
44 
/** Is `true` iff `Set` is a set-like container, that is, provides membership
 * checking via a `contains` member.
 *
 * TODO extend to also detect the `in` operator
 * TODO Move to Phobos std.traits
 */
template hasContains(Set)
{
    import std.traits : hasMember;
    // Must be named like the template (eponymous) so that `hasContains!T`
    // evaluates directly to this `bool`.
    enum hasContains = hasMember!(Set, "contains");
}
55 
/** Is `true` iff `Map` is a built-in associative array type (`V[K]`).
 *
 * User-defined map-like containers, for example ones providing only the `in`
 * operator or a `contains` member, are not detected.
 *
 * TODO check if `in` operator returns reference to value
 * TODO Move to Phobos std.traits
 */
template isAA(Map)
{
    import std.traits : isAssociativeArray;
    enum isAA = isAssociativeArray!Map;
}
66 
/// union of associative arrays (via keys)
@safe pure unittest
{
    alias Map = string[int];

    Map x = [0 : "a", 1 : "b"];
    Map y = [2 : "c"];

    Map c = [0 : "a", 1 : "b", 2 : "c"];

    // test commutativity (the key sets are disjoint)
    assert(setUnionUpdate(x, y) == c);
    assert(setUnionUpdate(y, x) == c);

    // the operands themselves are left unchanged
    assert(x == [0 : "a", 1 : "b"]);
    assert(y == [2 : "c"]);
}
81 
82 import std.traits : CommonType;
83 import std.range.primitives;
84 import std.meta : allSatisfy, staticMap;
85 import std.functional : binaryFun;
86 import std.range : SearchPolicy;
87 import nxt.range_ex : haveCommonElementType;
88 
/** Intersection of two or more ranges of type `Rs`, all sorted by `less`.
 *
 * An input that is a random-access range with slicing and length is advanced
 * with a sorted search using `preferredSearchPolicy`. Any other input is
 * advanced one `popFront` at a time.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
private struct SetIntersectionFast(alias less = "a < b",
                                   SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                                   Rs...)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
private:
    Rs _inputs;
    alias comp = binaryFun!less;
    alias ElementType = CommonType!(staticMap!(.ElementType, Rs));

    // Advances the inputs until their fronts are all equivalent under `less`,
    // or until some input becomes empty.
    void adjustPosition()
    {
        if (empty) return;

        // Number of consecutive cyclic pair checks (input `i` vs. input `i + 1`)
        // that must still pass without advancing anything before all fronts are
        // known to be equivalent.
        auto compsLeft = Rs.length;
        while (true)
        {
            foreach (i, ref r; _inputs)
            {
                alias next = _inputs[(i + 1) % Rs.length];
                alias Next = typeof(next);

                import std.range.primitives : isRandomAccessRange, hasLength, hasSlicing;
                static if (isRandomAccessRange!Next &&
                           hasSlicing!Next &&
                           hasLength!Next)
                {
                    // TODO Use the sorted search only when next.length / r.length > 12
                    import std.range : assumeSorted;

                    // Drop the prefix of `next` whose elements are ordered strictly
                    // before `r.front`. Searching with `less` itself works for
                    // any ordering predicate, not just "a < b" / "a > b".
                    immutable skipCount = next.assumeSorted!less
                                              .lowerBound!preferredSearchPolicy(r.front)
                                              .length;
                    next = next[skipCount .. next.length];

                    if (next.empty)
                    {
                        return; // next became empty, so the intersection is empty
                    }
                    else if (comp(r.front, next.front))
                    {
                        // `next.front` is a new, larger candidate: restart counting from it.
                        compsLeft = Rs.length;
                    }
                }
                else
                {
                    if (comp(next.front, r.front))
                    {
                        do
                        {
                            next.popFront();
                            if (next.empty) return;
                        }
                        while (comp(next.front, r.front));
                        compsLeft = Rs.length;
                    }
                }
                // After `Rs.length` consecutive checks without advancing, all
                // fronts are equivalent and we are positioned on a common element.
                if (--compsLeft == 0) return;
            }
        }
    }

public:
    /// Construct from `inputs` and position on the first common element.
    this(Rs inputs)
    {
        import std.functional : forward;
        this._inputs = forward!inputs; // TODO remove `forward` when compiler does it for us
        adjustPosition();
    }

    /// Returns: `true` iff any of the inputs is empty.
    @property bool empty()
    {
        foreach (ref r; _inputs)
        {
            if (r.empty) return true;
        }
        return false;
    }

    /// Advance past the current common element.
    void popFront()
    {
        assert(!empty);
        // `adjustPosition` guarantees that all fronts are equivalent here.
        foreach (i, ref r; _inputs)
        {
            alias next = _inputs[(i + 1) % Rs.length];
            assert(!comp(r.front, next.front));
        }

        foreach (ref r; _inputs)
        {
            r.popFront();
        }
        adjustPosition();
    }

    /// Returns: the current common element, read from the first input.
    @property ElementType front()
    {
        assert(!empty);
        return _inputs[0].front;
    }

    static if (allSatisfy!(isForwardRange, Rs))
    {
        /// Returns: a copy whose inputs are `save`d, so it advances independently.
        @property SetIntersectionFast save()
        {
            auto ret = this;
            foreach (i, ref r; _inputs)
            {
                ret._inputs[i] = r.save;
            }
            return ret;
        }
    }
}
223 
224 import core.internal.traits : Unqual;
225 
/** Wrap `r`, assumed to be sorted by `pred`, in a `MoveableSortedRange`.

    `r` is moved (not copied) into the wrapper, and no sorting is performed.
*/
auto assumeMoveableSorted(alias pred = "a < b", R)(R r)
if (isInputRange!(Unqual!R))
{
    import core.lifetime : move;
    alias Wrapped = MoveableSortedRange!(Unqual!R, pred);
    return Wrapped(move(r)); // TODO remove `move` when compiler does it for us
}
232 
/** Get intersection of `ranges`, each of which must be sorted by `less`.
 *
 * Returns: a `MoveableSortedRange` over the common elements, marked as sorted
 * by the same predicate `less` as the inputs.
 *
 * See_Also: https://forum.dlang.org/post/puwffthbqaktlqnourrs@forum.dlang.org
 */
MoveableSortedRange!(SetIntersectionFast!(less, preferredSearchPolicy, Rs), less)
setIntersectionFast(alias less = "a < b",
                    SearchPolicy preferredSearchPolicy = SearchPolicy.gallop,
                    Rs...)(Rs ranges)
if (Rs.length >= 2 &&
    allSatisfy!(isInputRange, Rs) &&
    haveCommonElementType!Rs)
{
    alias Intersection = SetIntersectionFast!(less, preferredSearchPolicy, Rs);
    // The result is wrapped as sorted by `less`, the ordering of the inputs,
    // rather than by the default predicate.

    // TODO Remove need for these switch cases if this can be fixed:
    // http://forum.dlang.org/post/pknonazfniihvpicxbld@forum.dlang.org
    static if (Rs.length == 2)
    {
        import core.lifetime : move;
        return assumeMoveableSorted!less(Intersection(move(ranges[0]), // TODO remove `move` when compiler does it for us
                                                      move(ranges[1]))); // TODO remove `move` when compiler does it for us
    }
    else
    {
        import std.functional : forward;
        return assumeMoveableSorted!less(Intersection(forward!ranges)); // TODO remove `forward` when compiler does it for us
    }
}
264 
@safe unittest
{
    import std.algorithm.comparison : equal;
    import std.algorithm.sorting : sort;
    import std.algorithm.setops : setIntersection;
    import nxt.random_ex : randInPlaceWithElementRange;
    import nxt.dynamic_array : DynamicArray;

    alias E = ulong;
    alias A = DynamicArray!E;

    auto a0 = A();
    auto a1 = A(1);

    enum less = "a < b";

    auto s0 = setIntersectionFast!(less)(a0[], a0[]);
    assert(s0.equal(a0[]));

    auto s1 = setIntersectionFast!(less)(a1[], a1[]);
    assert(s1.equal(a1[]));

    immutable smallTestLength = 1000;
    immutable factor = 12; // this is the magical limit on my laptop when performance of `upperBound` beats standard implementation
    immutable largeTestLength = factor*smallTestLength;
    E elementLow = 0;
    E elementHigh = 10_000_000;
    auto x = A.withLength(smallTestLength);
    auto y = A.withLength(largeTestLength);

    x[].randInPlaceWithElementRange(elementLow, elementHigh);
    y[].randInPlaceWithElementRange(elementLow, elementHigh);

    sort(x[]);
    sort(y[]);

    // commutative
    assert(equal(setIntersectionFast!(less)(x[], y[]),
                 setIntersectionFast!(less)(y[], x[])));

    // same result as Phobos' `setIntersection`
    assert(equal(setIntersection!(less)(x[], y[]),
                 setIntersectionFast!(less)(x[], y[])));

    // The benchmark only runs when its results are printed.
    version(show)
    {
        import nxt.algorithm_ex : collect;
        import std.datetime.stopwatch : benchmark;
        import core.time : Duration;
        import std.stdio : writeln;
        import std.conv : to;

        void testSetIntersection()
        {
            auto z = setIntersection!(less)(x[], y[]).collect!A;
        }

        void testSetIntersectionNew()
        {
            auto z = setIntersectionFast!(less)(x[], y[]).collect!A;
        }

        immutable testCount = 10;
        auto r = benchmark!(testSetIntersection,
                            testSetIntersectionNew)(testCount);

        writeln("old testSetIntersection: ", to!Duration(r[0]));
        writeln("new testSetIntersection: ", to!Duration(r[1]));
    }
}
334 
@safe pure nothrow unittest
{
    import std.algorithm.comparison : equal;
    enum less = "a < b";
    auto si = setIntersectionFast!(less)([1, 2, 3],
                                         [1, 2, 3]);
    auto sic = si.save();
    assert(si.equal([1, 2, 3]));

    // advancing the original must not affect the saved copy
    si.popFront();
    assert(si.equal([2, 3]));
    assert(sic.equal([1, 2, 3]));
}
344 
// TODO remove this `MoveableSortedRange` and replace with Phobos' `SortedRange` in this buffer
/** Sorted-range wrapper modelled on Phobos' `std.range.SortedRange`, except
 * that the wrapped range is moved in by the constructor and moved out by
 * `release`.
 *
 * `Range` is assumed to be sorted according to `pred`. Apart from a sampled
 * check in debug builds (see `dbgVerifySorted`), this is not verified.
 */
struct MoveableSortedRange(Range, alias pred = "a < b")
if (isInputRange!Range)
{
    import std.functional : binaryFun;
    // `isTwoWayCompatible` is needed by the search constraints below, and
    // `stride` by `dbgVerifySorted`.
    import std.range : isTwoWayCompatible, stride;

    private alias predFun = binaryFun!pred;

    // `true` iff `lhs` is not ordered before `rhs`.
    private bool geq(L, R)(L lhs, R rhs)
    {
        return !predFun(lhs, rhs);
    }

    // `true` iff `rhs` is ordered before `lhs`.
    private bool gt(L, R)(L lhs, R rhs)
    {
        return predFun(rhs, lhs);
    }

    private Range _input;

    // Undocumented because a clearer way to invoke is by calling
    // `assumeMoveableSorted`.
    this(Range input)
    out
    {
        // moved out of the body as a workaround for Issue 12661
        dbgVerifySorted();
    }
    do
    {
        import core.lifetime : move;
        this._input = move(input); // TODO remove `move` when compiler does it for us
    }

    // Assertion only. In debug builds, outside CTFE, a random-access input with
    // length is spot-checked for sortedness by sampling every
    // `length / (bsr(length) + 1)`-th element.
    private void dbgVerifySorted()
    {
        if (!__ctfe)
        debug
        {
            static if (isRandomAccessRange!Range && hasLength!Range)
            {
                import core.bitop : bsr;
                import std.algorithm.sorting : isSorted;

                // Check the sortedness of the input
                if (this._input.length < 2) return;

                immutable size_t msb = bsr(this._input.length) + 1;
                assert(msb > 0 && msb <= this._input.length);
                immutable step = this._input.length / msb;
                auto st = stride(this._input, step);

                assert(isSorted!pred(st), "Range is not sorted");
            }
        }
    }

    /// Range primitives.
    @property bool empty()             //const
    {
        return this._input.empty;
    }

    /// Ditto
    static if (isForwardRange!Range)
    @property auto save()
    {
        // Avoid the constructor
        typeof(this) result = this;
        result._input = _input.save;
        return result;
    }

    /// Ditto
    @property auto ref front()
    {
        return _input.front;
    }

    /// Ditto
    void popFront()
    {
        _input.popFront();
    }

    /// Ditto
    static if (isBidirectionalRange!Range)
    {
        @property auto ref back()
        {
            return _input.back;
        }

        /// Ditto
        void popBack()
        {
            _input.popBack();
        }
    }

    /// Ditto
    static if (isRandomAccessRange!Range)
        auto ref opIndex(size_t i)
        {
            return _input[i];
        }

    /// Ditto
    static if (hasSlicing!Range)
        auto opSlice(size_t a, size_t b)
        {
            assert(
                a <= b,
                "Attempting to slice a SortedRange with a larger first argument than the second."
            );
            typeof(this) result = this;
            result._input = _input[a .. b];// skip checking
            return result;
        }

    /// Ditto
    static if (hasLength!Range)
    {
        @property size_t length()          //const
        {
            return _input.length;
        }
        alias opDollar = length;
    }

    /**
       Releases the controlled range and returns it (moved out of `this`).
    */
    auto release()
    {
        import core.lifetime : move;
        return move(_input);
    }

    // Assuming a predicate "test" that returns 0 for a left portion
    // of the range and then 1 for the rest, returns the index at
    // which the first 1 appears. Used internally by the search routines.
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if (sp == SearchPolicy.binarySearch && isRandomAccessRange!Range && hasLength!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2, it = first + step;
            if (!test(_input[it], v))
            {
                first = it + 1;
                count -= step + 1;
            }
            else
            {
                count = step;
            }
        }
        return first;
    }

    // Specialization for trot and gallop: probe forward from the front with a
    // growing step, then binary-search the last bracket.
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trot || sp == SearchPolicy.gallop)
        && isRandomAccessRange!Range)
    {
        if (empty || test(front, v)) return 0;
        immutable count = length;
        if (count == 1) return 1;
        size_t below = 0, above = 1, step = 2;
        while (!test(_input[above], v))
        {
            // Still too small, update below and increase gait
            below = above;
            immutable next = above + step;
            if (next >= count)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply exit the loop to do the
                // binary search thingie.
                above = count;
                break;
            }
            // Still in business, increase step and continue
            above = next;
            static if (sp == SearchPolicy.trot)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

    // Specialization for trotBackwards and gallopBackwards: probe backward from
    // the back with a growing step, then binary-search the last bracket.
    private size_t getTransitionIndex(SearchPolicy sp, alias test, V)(V v)
    if ((sp == SearchPolicy.trotBackwards || sp == SearchPolicy.gallopBackwards)
        && isRandomAccessRange!Range)
    {
        immutable count = length;
        if (empty || !test(back, v)) return count;
        if (count == 1) return 0;
        size_t below = count - 2, above = count - 1, step = 2;
        while (test(_input[below], v))
        {
            // Still too large, update above and increase gait
            above = below;
            if (below < step)
            {
                // Overshot - the next step took us beyond the end. So
                // now adjust next and simply fall through to do the
                // binary search thingie.
                below = 0;
                break;
            }
            // Still in business, increase step and continue
            below -= step;
            // The backwards policies must be checked here; comparing against
            // `SearchPolicy.trot` would make `trotBackwards` gallop.
            static if (sp == SearchPolicy.trotBackwards)
                ++step;
            else
                step <<= 1;
        }
        return below + this[below .. above].getTransitionIndex!(
            SearchPolicy.binarySearch, test, V)(v);
    }

// lowerBound
/**
   This function uses a search with policy $(D sp) to find the
   largest left subrange on which $(D pred(x, value)) is `true` for
   all $(D x) (e.g., if $(D pred) is "less than", returns the portion of
   the range with elements strictly smaller than `value`). The search
   schedule and its complexity are documented in
   $(LREF SearchPolicy).  See also STL's
   $(HTTP sgi.com/tech/stl/lower_bound.html, lower_bound).
*/
    auto lowerBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
         && hasSlicing!Range)
    {
        return this[0 .. getTransitionIndex!(sp, geq)(value)];
    }

// upperBound
/**
This function searches with policy $(D sp) to find the largest right
subrange on which $(D pred(value, x)) is `true` for all $(D x)
(e.g., if $(D pred) is "less than", returns the portion of the range
with elements strictly greater than `value`). The search schedule
and its complexity are documented in $(LREF SearchPolicy).

For ranges that do not offer random access, $(D SearchPolicy.linear)
is the only policy allowed (and it must be specified explicitly lest it exposes
user code to unexpected inefficiencies). For random-access searches, all
policies are allowed, and $(D SearchPolicy.binarySearch) is the default.

See_Also: STL's $(HTTP sgi.com/tech/stl/upper_bound.html, upper_bound).
*/
    auto upperBound(SearchPolicy sp = SearchPolicy.binarySearch, V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V))
    {
        static assert(hasSlicing!Range || sp == SearchPolicy.linear,
            "Specify SearchPolicy.linear explicitly for "
            ~ typeof(this).stringof);
        static if (sp == SearchPolicy.linear)
        {
            for (; !_input.empty && !predFun(value, _input.front);
                 _input.popFront())
            {
            }
            return this;
        }
        else
        {
            return this[getTransitionIndex!(sp, gt)(value) .. length];
        }
    }


// equalRange
/**
   Returns the subrange containing all elements $(D e) for which both $(D
   pred(e, value)) and $(D pred(value, e)) evaluate to $(D false) (e.g.,
   if $(D pred) is "less than", returns the portion of the range with
   elements equal to `value`). Uses a classic binary search with
   interval halving until it finds a value that satisfies the condition,
   then uses $(D SearchPolicy.gallopBackwards) to find the left boundary
   and $(D SearchPolicy.gallop) to find the right boundary. These
   policies are justified by the fact that the two boundaries are likely
   to be near the first found value (i.e., equal ranges are relatively
   small). Completes the entire search in $(BIGOH log(n)) time. See also
   STL's $(HTTP sgi.com/tech/stl/equal_range.html, equal_range).
*/
    auto equalRange(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range)
    {
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return this[left .. right];
            }
        }
        return this.init;
    }

// trisect
/**
Returns a tuple $(D r) such that $(D r[0]) is the same as the result
of $(D lowerBound(value)), $(D r[1]) is the same as the result of $(D
equalRange(value)), and $(D r[2]) is the same as the result of $(D
upperBound(value)). The call is faster than computing all three
separately. Uses a search schedule similar to $(D
equalRange). Completes the entire search in $(BIGOH log(n)) time.
*/
    auto trisect(V)(V value)
    if (isTwoWayCompatible!(predFun, ElementType!Range, V)
        && isRandomAccessRange!Range && hasLength!Range)
    {
        import std.typecons : tuple;
        size_t first = 0, count = _input.length;
        while (count > 0)
        {
            immutable step = count / 2;
            auto it = first + step;
            if (predFun(_input[it], value))
            {
                // Less than value, bump left bound up
                first = it + 1;
                count -= step + 1;
            }
            else if (predFun(value, _input[it]))
            {
                // Greater than value, chop count
                count = step;
            }
            else
            {
                // Equal to value, do binary searches in the
                // leftover portions
                // Gallop towards the left end as it's likely nearby
                immutable left = first
                    + this[first .. it]
                    .lowerBound!(SearchPolicy.gallopBackwards)(value).length;
                first += count;
                // Gallop towards the right end as it's likely nearby
                immutable right = first
                    - this[it + 1 .. first]
                    .upperBound!(SearchPolicy.gallop)(value).length;
                return tuple(this[0 .. left], this[left .. right],
                        this[right .. length]);
            }
        }
        // No equal element was found
        return tuple(this[0 .. first], this.init, this[first .. length]);
    }

// contains
/**
Returns `true` if and only if `value` can be found in $(D
range), which is assumed to be sorted. Performs $(BIGOH log(r.length))
evaluations of $(D pred). See also STL's $(HTTP
sgi.com/tech/stl/binary_search.html, binary_search).
 */

    bool contains(V)(V value)
    if (isRandomAccessRange!Range)
    {
        if (empty) return false;
        immutable i = getTransitionIndex!(SearchPolicy.binarySearch, geq)(value);
        if (i >= length) return false;
        return !predFun(value, _input[i]);
    }

// groupBy
/**
Returns a range of subranges of elements that are equivalent according to the
sorting relation.
 */
    auto groupBy()()
    {
        import std.algorithm.iteration : chunkBy;
        return _input.chunkBy!((a, b) => !predFun(a, b) && !predFun(b, a));
    }
}