1 /** Integer Sorting Algorithms.
2     Copyright: Per Nordlöw 2018-.
3     License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
4     Authors: $(WEB Per Nordlöw)
5  */
6 module nxt.integer_sorting;
7 
8 import std.range.primitives : isRandomAccessRange, ElementType;
9 import std.traits : isNumeric;
10 import std.meta : AliasSeq;
11 
12 import nxt.bijections;
13 
/** Radix sort of `input`.

    Sorts the numeric elements of the random-access range `input` and returns
    `input` wrapped by `assumeSorted`.

    The default (non-in-place) variant is a least-significant-digit-first
    counting sort that uses a temporary buffer of `input.length` elements.
    Elements are mapped to unsigned integers via `bijectToUnsigned` before their
    digits are extracted.

    The number of bits per radix digit (`radixBitCount`) is chosen from the size
    of the element type: 8 bits for 8- and 24-bit elements, 16 bits for 16-, 32-
    and 64-bit elements.

    Params:
        fun = currently unused (see TODO below)
        descending = sort in descending instead of ascending order
        requestDigitDiscardal = skip digit passes in which no element differs
            from its predecessor in that digit. Only has effect when the
            element type spans at least two digits.
        inPlace = select the most-significant-digit-first in-place variant.
            NOTE: this variant is unfinished; it only computes the bin
            boundaries and does not reorder `input`.
        input = range to sort

    Returns: `input` wrapped by `assumeSorted` with predicate `"a > b"` when
    `descending` is set, otherwise `"a < b"`.

    TODO: make `radixBitCount` a template parameter either 8 or 16,
    ElementType.sizeof must be a multiple of radixBitCount

    TODO: input[] = y[] not needed when input is mutable

    TODO: Restrict fun.

    TODO: Choose fastDigitDiscardal based on elementMin and elementMax (if they
    are given)

    See_Also: https://probablydance.com/2016/12/27/i-wrote-a-faster-sorting-algorithm/
    See_Also: https://github.com/skarupke/ska_sort/blob/master/ska_sort.hpp
    See_Also: http://forum.dlang.org/thread/vmytpazcusauxypkwdbn@forum.dlang.org#post-vmytpazcusauxypkwdbn:40forum.dlang.org
 */
auto radixSort(R,
               alias fun = "a",
               bool descending = false,
               bool requestDigitDiscardal = false,
               bool inPlace = false)(R input,
                                     /* ElementType!R elementMin = ElementType!(R).max, */
                                     /* ElementType!R elementMax = ElementType!(R).min */)

    @trusted
if (isRandomAccessRange!R &&
    (isNumeric!(ElementType!R)))
{
    import std.range : assumeSorted;
    import std.algorithm.sorting : isSorted; // TODO: move this to radixSort when know how map less to descending
    import std.range.primitives : front;

    immutable n = input.length; // number of elements
    alias E = ElementType!R;
    enum elementBitCount = 8*E.sizeof; // total number of bits needed to code each element

    /* Lookup number of radix bits from sizeof `ElementType`.
       These give optimal performance on Intel Core i7.
    */
    static if (elementBitCount == 8 ||
               elementBitCount == 24)
    {
        enum radixBitCount = 8;
    }
    else static if (elementBitCount == 16 ||
                    elementBitCount == 32 ||
                    elementBitCount == 64)
    {
        enum radixBitCount = 16;
    }
    else
    {
        static assert(0, "TODO: handle element type " ~ E.stringof);
    }

    // TODO: activate this: subtract min from all values and then immutable uint elementBitCount = is_min(a_max) ? 8*sizeof(E) : binlog(a_max); and add it back.
    enum digitCount = elementBitCount / radixBitCount;         // number of digits in radix `radixBitCount`
    static assert(elementBitCount % radixBitCount == 0,
                  "Precision of ElementType must be evenly divisble by bit-precision of Radix.");

    enum doDigitDiscardal = requestDigitDiscardal && digitCount >= 2;

    enum radix = cast(typeof(radixBitCount))1 << radixBitCount;    // bin count
    enum mask = radix-1;                                     // radix bit mask

    alias UE = typeof(input.front.bijectToUnsigned); // get unsigned integer type of same precision as \tparam E.

    import nxt.fixed_dynamic_array : FixedDynamicArray;

    static if (inPlace) // most-significant digit (MSD) first in-place radix sort
    {
        static assert(!descending, "TODO: Implement descending version");

        foreach (immutable digitOffsetReversed; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with most significant (MSD-first)
        {
            immutable digitOffset = digitCount - 1 - digitOffsetReversed;
            immutable digitBitshift = digitOffset*radixBitCount; // digit bit shift

            // [lowOffsets[i], highOffsets[i]] will become slices into `input`
            size_t[radix] lowOffsets; // low offsets for each bin
            size_t[radix] highOffsets; // high offsets for each bin

            // calculate counts
            foreach (immutable j; 0 .. n) // for each element index `j` in `input`
            {
                immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                ++highOffsets[i];   // increase histogram bin counter
            }

            // bin boundaries: accumulate bin counters array
            lowOffsets[0] = 0;             // first low is always zero
            foreach (immutable j; 1 .. radix) // for each successive bin counter
            {
                lowOffsets[j] = highOffsets[j - 1]; // previous roof becomes current floor
                highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
            }
            assert(highOffsets[radix - 1] == n); // should equal high offset of last bin
        }

        // /** \em unstable in-place (permutate) reorder/sort `input`
        //  * access `input`'s elements in \em reverse to \em reuse filled caches from previous forward iteration.
        //  * \see `in_place_indexed_reorder`
        //  */
        // for (int r = radix - 1; r >= 0; --r) // for each radix digit r in reverse order (cache-friendly)
        // {
        //     while (binStat[r])  // as long as elements left in r:th bucket
        //     {
        //         immutable uint i0 = binStat[r].pop_back(); // index to first element of permutation
        //         immutable E    e0 = input[i0]; // value of first/current element of permutation
        //         while (true)
        //         {
        //             immutable int rN = (e0.bijectToUnsigned(descending) >> digitBitshift) & mask; // next digit (index)
        //             if (r == rN) // if permutation cycle closed (back to same digit)
        //                 break;
        //             immutable ai = binStat[rN].pop_back(); // array index
        //             swap(input[ai], e0); // do swap
        //         }
        //         input[i0] = e0;         // complete cycle
        //     }
        // }

        // TODO: copy reorder algorithm into local function that calls itself in the recursion step
        // TODO: call this local function

        assert(input.isSorted!"a < b");
    }
    else                        // standard (LSD-first, out-of-place) radix sort
    {
        // A range with fewer than two elements is already sorted. Skipping it
        // also avoids reading `input[0]` of an empty range below.
        if (n >= 2)
        {
            import std.algorithm.mutation : copy, swap;

            // non-in-place requires temporary `y`. TODO: we could allocate these as
            // a stack-allocated array for small arrays and gain extra speed.
            auto tempStorage = FixedDynamicArray!E.makeUninitializedOfLength(n);
            auto tempSlice = tempStorage[];

            // With an even number of digits the two buffers are swapped after
            // each pass instead of copied back.
            enum swapBuffers = (digitCount & 1) == 0;
            static if (swapBuffers)
            {
                bool inputIsTemp = false; // true while `input` refers to `tempStorage`
            }

            static if (doDigitDiscardal)
            {
                UE ors  = 0;         // digits diff(xor)-or-sum
            }

            foreach (immutable digitOffset; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with least significant (LSD-first)
            {
                immutable digitBitshift = digitOffset*radixBitCount;   // digit bit shift

                static if (doDigitDiscardal)
                {
                    if (digitOffset != 0) // if first iteration already performed we have bit statistics
                    {
                        if ((! ((ors >> digitBitshift) & mask))) // if no bit of this digit differs between any adjacent elements
                        {
                            continue;               // no sorting is needed for this digit
                        }
                    }
                }

                // calculate counts
                size_t[radix] highOffsets; // histogram buckets count and later upper-limits/walls for values in `input`
                UE previousUnsignedValue = cast(UE)input[0].bijectToUnsigned(descending);
                foreach (immutable j; 0 .. n) // for each element index `j` in `input`
                {
                    immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                    static if (doDigitDiscardal)
                    {
                        if (digitOffset == 0) // first iteration calculates statistics
                        {
                            ors |= previousUnsignedValue ^ currentUnsignedValue; // accumulate bit change statistics
                            // ors |= currentUnsignedValue; // accumulate bits statistics
                        }
                    }
                    immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                    ++highOffsets[i];              // increase histogram bin counter
                    previousUnsignedValue = currentUnsignedValue;
                }

                static if (doDigitDiscardal)
                {
                    if (digitOffset == 0) // statistics have just been calculated in this first iteration
                    {
                        if ((! ((ors >> digitBitshift) & mask))) // if no bit of this digit differs between any adjacent elements
                        {
                            continue;               // no sorting is needed for this digit
                        }
                    }
                }

                // bin boundaries: accumulate bin counters array
                foreach (immutable j; 1 .. radix) // for each successive bin counter
                {
                    highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
                }
                assert(highOffsets[radix - 1] == n); // should equal high offset of last bin

                // reorder. access `input`'s elements in \em reverse to \em reuse filled caches from previous forward iteration.
                // \em stable reorder from `input` to `tempSlice` using normal counting sort.
                enum unrollFactor = 1;
                assert((n % unrollFactor) == 0, "TODO: Add reordering for remainder"); // is evenly divisible by unroll factor
                for (size_t j = n - 1; j < n; j -= unrollFactor) // for each element `j` in reverse order. when `j` wraps around `j` < `n` is no longer true
                {
                    static foreach (k; 0 .. unrollFactor) // inlined (unrolled) loop
                    {
                        immutable i = (input[j - k].bijectToUnsigned(descending) >> digitBitshift) & mask; // digit (index)
                        tempSlice[--highOffsets[i]] = input[j - k]; // reorder into tempSlice
                    }
                }
                assert(highOffsets[0] == 0); // should equal low offset of first bin

                static if (swapBuffers)
                {
                    swap(input, tempSlice);
                    inputIsTemp = !inputIsTemp;
                }
                else            // odd number of digit passes: copy back after each pass
                {
                    static if (__traits(compiles, input[] = tempSlice[]))
                    {
                        input[] = tempSlice[]; // faster than std.algorithm.copy() because input never overlap tempSlice
                    }
                    else
                    {
                        copy(tempSlice[], input[]); // TODO: use memcpy
                    }
                }
            }

            static if (swapBuffers)
            {
                // Discarded digit passes can leave an odd number of swaps, in
                // which case the sorted elements live in `tempStorage`. Move them
                // to the caller's storage (now referred to by `tempSlice`) and
                // make `input` refer to that storage again before `tempStorage`
                // goes out of scope.
                if (inputIsTemp)
                {
                    copy(input, tempSlice);
                    swap(input, tempSlice);
                }
            }
        }
    }

    static if (descending)
    {
        return input.assumeSorted!"a > b";
    }
    else
    {
        return input.assumeSorted!"a < b";
    }
}
257 
version = benchmark;
// version = show;

/* Checks `radixSort` (descending, ascending and ascending with digit
   discardal) against a stable `std.algorithm.sorting.sort` on random data, for
   all built-in integer widths plus `float` and `double`. With `version = show`
   the timing ratios are printed as well.
*/
version(benchmark)
@safe unittest
{
    version(show) import std.stdio : write, writef, writeln;

    /** Test `radixSort` with element-type `E` on `n` random elements. */
    void test(E)(int n) @safe
    {
        version(show) writef("%8-s, %10-s, ", E.stringof, n);

        import nxt.dynamic_array : Array = DynamicArray;
        import nxt.random_ex : randInPlace;
        import std.algorithm.comparison : min, equal;
        import std.algorithm.mutation : SwapStrategy;
        import std.algorithm.sorting : sort, isSorted;
        import std.datetime.stopwatch : StopWatch;
        import std.range : retro;

        auto timer = StopWatch();
        immutable nMax = 5;     // number of leading elements printed by `version(show)`

        // random input shared by all sorts below
        auto original = Array!E.withLength(n);
        // alternative for unsigned `E`: original[].randInPlaceWithElementRange(cast(E)0, cast(E)uint.max);
        original[].randInPlace();
        version(show) write("original random: ", original[0 .. min(nMax, $)], ", ");

        // reference result and reference timing using Phobos' stable sort
        auto reference = original.dup;

        timer.reset();
        timer.start();
        sort!("a < b", SwapStrategy.stable)(reference[]);
        timer.stop();
        immutable sortTimeUsecs = timer.peek.total!"usecs";
        version(show) write("quick sorted: ", reference[0 .. min(nMax, $)], ", ");
        assert(isSorted(reference[]));

        // descending radix sort must equal the reversed reference
        {
            auto work = original.dup;
            radixSort!(typeof(work[]), "a", true)(work[]);
            version(show) write("reverse radix sorted: ", work[0 .. min(nMax, $)], ", ");
            assert(equal(retro(work[]), reference[]));
        }

        // ascending radix sort
        {
            auto work = original.dup;

            timer.reset();
            timer.start();
            radixSort!(typeof(work[]), "b", false)(work[]);
            timer.stop();
            immutable radixTime1 = timer.peek.total!"usecs";

            version(show) writef("%9-s, ", cast(real)sortTimeUsecs / radixTime1);
            assert(equal(work[], reference[]));
        }

        // ascending radix sort with digit discardal requested
        {
            auto work = original.dup;

            timer.reset();
            timer.start();
            radixSort!(typeof(work[]), "b", false, true)(work[]);
            timer.stop();
            immutable radixTime = timer.peek.total!"usecs";

            assert(equal(work[], reference[]));

            version(show)
            {
                writeln("standard radix sorted with fast-discardal: ",
                        work[0 .. min(nMax, $)]);
            }
            version(show) writef("%9-s, ", cast(real)sortTimeUsecs / radixTime);
        }

        // inplace-place radix sort
        // static if (is(E == uint))
        // {
        //     auto b = a.dup;

        //     sw.reset;
        //     sw.start();
        //     b[].radixSort!(typeof(b[]), "b", false, false, true)();
        //     sw.stop;
        //     immutable radixTime = sw.peek.usecs;

        //     assert(b[].equal(qa[]));

        //     version(show)
        //     {
        //         writeln("in-place radix sorted with fast-discardal: ",
        //                 b[0 .. min(nMax, $)]);
        //     }
        //     writef("%9-s, ", cast(real)sortTimeUsecs / radixTime);
        // }

        version(show) writeln("");
    }

    import std.meta : AliasSeq;
    immutable n = 100_000;
    version(show) writeln("EType, eCount, radixSort (speed-up), radixSort with fast discardal (speed-up), in-place radixSort (speed-up)");
    // each signed integer type is followed by its unsigned counterpart, then the floating point types
    foreach (E; AliasSeq!(byte, ubyte,
                          short, ushort,
                          int, uint,
                          long, ulong,
                          float, double))
    {
        test!E(n);
    }
}
385 
/* For every supported element type: `radixSort` with default template
   arguments must yield the same sorted range as `std.algorithm.sorting.sort`
   on identical random input.
*/
@safe unittest
{
    import std.meta : AliasSeq;
    import std.algorithm.sorting : sort;
    import std.algorithm.mutation : swap;
    import nxt.dynamic_array : Array = DynamicArray;
    import nxt.random_ex : randInPlace;

    immutable elementCount = 10_000;

    foreach (E; AliasSeq!(byte, ubyte,
                          short, ushort,
                          int, uint,
                          long, ulong,
                          float, double))
    {
        // two containers holding the same random elements
        auto radixInput = Array!E.withLength(elementCount);
        radixInput[].randInPlace();
        auto sortInput = radixInput.dup;

        // sort one with each algorithm and compare the resulting sorted ranges
        assert(radixSort(radixInput[]) == sort(sortInput[]));

        swap(radixInput, sortInput);
    }
}