1 /** Integer Sorting Algorithms.
2     Copyright: Per Nordlöw 2018-.
3     License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
4     Authors: $(WEB Per Nordlöw)
5  */
6 module nxt.integer_sorting;
7 
8 import std.range.primitives : isRandomAccessRange, ElementType;
9 import std.traits : isNumeric;
10 import std.meta : AliasSeq;
11 
12 import nxt.bijections;
13 
/** Radix sort of `input`.

    Sorts the elements of `input` and returns `input` wrapped by
    `assumeSorted`, using the predicate `"a < b"` (ascending) or `"a > b"`
    (`descending`).

    The default (non-in-place) variant is a least-significant-digit-first
    (LSD) radix sort. Each digit is distributed by a stable counting-sort pass
    into a temporary buffer of `input.length` elements. After the last pass,
    the sorted elements are back in the memory viewed by the caller's `input`.

    Digits are extracted from `bijectToUnsigned` of each element (presumably
    provided by `nxt.bijections`). That mapping is assumed to preserve the
    element order, or to reverse it when `descending` is true. TODO confirm
    against `nxt.bijections`.

    The number of bits per digit (`radixBitCount`) is selected from the
    element size:
    - 8 bits for 8- and 24-bit elements;
    - 16 bits for 16-, 32- and 64-bit elements.

    Params:
        R = type of `input`. The template constraint requires a random-access
            range of numeric elements. When the element type has an even
            number of digits, `input` is also `swap`ped with the temporary
            slice, so in practice `R` must then be that same slice type
            (e.g. `E[]`).
        fun = currently not used by the implementation.
        descending = sort in descending instead of ascending order.
        requestDigitDiscardal = skip the counting-sort pass of every digit
            that is identical in all elements. This only takes effect when the
            element type has at least two digits.
        inPlace = select the most-significant-digit-first (MSD) in-place
            variant. NOTE: that variant is unfinished. It does not yet
            reorder the elements.
        input = range whose elements are sorted.

    Returns: `input` wrapped in a `SortedRange`.

    TODO: make `radixBitCount` a template parameter either 8 or 16,
    ElementType.sizeof must be a multiple of radixBitCount

    TODO: input[] = y[] not needed when input is mutable

    TODO: Restrict fun.

    TODO: Choose fastDigitDiscardal based on elementMin and elementMax (if they
    are given)

    See_Also: https://probablydance.com/2016/12/27/i-wrote-a-faster-sorting-algorithm/
    See_Also: https://github.com/skarupke/ska_sort/blob/master/ska_sort.hpp
    See_Also: http://forum.dlang.org/thread/vmytpazcusauxypkwdbn@forum.dlang.org#post-vmytpazcusauxypkwdbn:40forum.dlang.org
 */
auto radixSort(R,
               alias fun = "a",
               bool descending = false,
               bool requestDigitDiscardal = false,
               bool inPlace = false)(R input,
                                     /* ElementType!R elementMin = ElementType!(R).max, */
                                     /* ElementType!R elementMax = ElementType!(R).min */)

    @trusted
if (isRandomAccessRange!R &&
    (isNumeric!(ElementType!R)))
{
    import std.range : assumeSorted;
    import std.algorithm.sorting : isSorted; // TODO: move this to radixSort when know how map less to descending
    import std.algorithm.comparison : min, max;
    import std.range.primitives : front;

    // Ordering predicate attached to the returned `SortedRange`.
    static if (descending)
        enum sortOrder = "a > b";
    else
        enum sortOrder = "a < b";

    immutable n = input.length; // number of elements

    // Zero or one element is already sorted. Returning early also keeps the
    // counting pass below from reading `input[0]` of an empty range.
    if (n < 2)
        return input.assumeSorted!sortOrder;

    alias E = ElementType!R;
    enum elementBitCount = 8*E.sizeof; // total number of bits needed to code each element

    /* Lookup number of radix bits from sizeof `ElementType`.
       These give optimal performance on Intel Core i7.
    */
    static if (elementBitCount == 8 ||
               elementBitCount == 24)
        enum radixBitCount = 8;
    else static if (elementBitCount == 16 ||
                    elementBitCount == 32 ||
                    elementBitCount == 64)
        enum radixBitCount = 16;
    else
        static assert(0, "TODO: handle element type " ~ E.stringof);

    // TODO: activate this: subtract min from all values and then immutable uint elementBitCount = is_min(a_max) ? 8*sizeof(E) : binlog(a_max); and add it back.
    enum digitCount = elementBitCount / radixBitCount; // number of digits (of `radixBitCount` bits each) per element
    static assert(elementBitCount % radixBitCount == 0,
                  "Precision of ElementType must be evenly divisble by bit-precision of Radix.");

    enum doDigitDiscardal = requestDigitDiscardal && digitCount >= 2;

    enum radix = cast(typeof(radixBitCount))1 << radixBitCount; // bin count
    enum mask = radix-1;                                         // radix bit mask

    alias UE = typeof(input.front.bijectToUnsigned); // unsigned integer type of the same precision as `E`

    import nxt.container.fixed_dynamic_array : FixedDynamicArray;

    static if (inPlace) // most-significant digit (MSD) first in-place radix sort
    {
        // NOTE: unfinished. Only the per-digit bin offsets are computed. The
        // element permutation (sketched in the commented-out code below) is
        // not performed yet, so the final `isSorted` assertion fails for
        // unsorted input.
        static assert(!descending, "TODO: Implement descending version");

        foreach (immutable digitOffsetReversed; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with the most significant (MSD-first)
        {
            immutable digitOffset = digitCount - 1 - digitOffsetReversed;
            immutable digitBitshift = digitOffset*radixBitCount; // digit bit shift

            // [lowOffsets[i], highOffsets[i]] will become slices into `input`
            size_t[radix] lowOffsets; // low offsets for each bin
            size_t[radix] highOffsets; // high offsets for each bin

            // calculate counts
            foreach (immutable j; 0 .. n) // for each element index `j` in `input`
            {
                immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                ++highOffsets[i];   // increase histogram bin counter
            }

            // bin boundaries: accumulate bin counters array
            lowOffsets[0] = 0;             // first low is always zero
            foreach (immutable j; 1 .. radix) // for each successive bin counter
            {
                lowOffsets[j] = highOffsets[j - 1]; // previous roof becomes current floor
                highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
            }
            assert(highOffsets[radix - 1] == n); // should equal high offset of last bin
        }

        // /** \em unstable in-place (permutate) reorder/sort `input`
        //  * access `input`'s elements in \em reverse to \em reuse filled caches from previous forward iteration.
        //  * \see `in_place_indexed_reorder`
        //  */
        // for (int r = radix - 1; r >= 0; --r) // for each radix digit r in reverse order (cache-friendly)
        // {
        //     while (binStat[r])  // as long as elements left in r:th bucket
        //     {
        //         immutable uint i0 = binStat[r].pop_back(); // index to first element of permutation
        //         immutable E    e0 = input[i0]; // value of first/current element of permutation
        //         while (true)
        //         {
        //             immutable int rN = (e0.bijectToUnsigned(descending) >> digitBitshift) & mask; // next digit (index)
        //             if (r == rN) // if permutation cycle closed (back to same digit)
        //                 break;
        //             immutable ai = binStat[rN].pop_back(); // array index
        //             swap(input[ai], e0); // do swap
        //         }
        //         input[i0] = e0;         // complete cycle
        //     }
        // }

        // TODO: copy reorder algorithm into local function that calls itself in the recursion step
        // TODO: call this local function

        assert(input.isSorted!"a < b");
    }
    else // LSD radix sort through a temporary buffer
    {
        import std.algorithm.mutation : swap;

        // Non-in-place sorting requires temporary storage. TODO: we could
        // allocate it as a stack-allocated array for small inputs and gain extra speed.
        auto tempStorage = FixedDynamicArray!E.makeUninitializedOfLength(n);
        auto tempSlice = tempStorage[];

        static if (!(digitCount & 1))
        {
            // With an even digit count, each pass swaps `input` and `tempSlice`
            // instead of copying back, so `input` views the caller's memory only
            // after an even number of passes. Passes skipped by digit discardal
            // can break that parity, so track it explicitly.
            bool inputViewsTemp = false;
        }

        static if (doDigitDiscardal)
        {
            UE ors = 0; // OR of XORs of consecutive elements: a set bit varies between some elements
        }

        foreach (immutable digitOffset; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with least significant (LSD-first)
        {
            immutable digitBitshift = digitOffset*radixBitCount; // digit bit shift

            static if (doDigitDiscardal)
                if (digitOffset != 0) // statistics were gathered during the first pass
                    if (!((ors >> digitBitshift) & mask)) // digit is identical in all elements
                        continue;                          // so no reordering is needed for it

            // calculate counts
            size_t[radix] highOffsets; // histogram bucket counts, later upper limits of each bin in `tempSlice`
            static if (doDigitDiscardal)
            {
                UE previousUnsignedValue = cast(UE)input[0].bijectToUnsigned(descending); // safe: n >= 2
            }
            foreach (immutable j; 0 .. n) // for each element index `j` in `input`
            {
                immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                static if (doDigitDiscardal)
                {
                    if (digitOffset == 0) // first pass gathers bit-change statistics
                        ors |= previousUnsignedValue ^ currentUnsignedValue;
                    previousUnsignedValue = currentUnsignedValue;
                }
                immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                ++highOffsets[i]; // increase histogram bin counter
            }

            static if (doDigitDiscardal)
                if (digitOffset == 0) // statistics are complete now, also for the first digit
                    if (!((ors >> digitBitshift) & mask)) // digit is identical in all elements
                        continue;                          // so no reordering is needed for it

            // bin boundaries: accumulate bin counters array
            foreach (immutable j; 1 .. radix) // for each successive bin counter
                highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
            assert(highOffsets[radix - 1] == n); // should equal high offset of last bin

            // Reorder. Access `input`'s elements in reverse to reuse caches
            // filled by the preceding forward iteration. Filling each bin from its
            // upper limit downwards keeps this counting-sort pass stable.
            enum unrollFactor = 1;
            assert((n % unrollFactor) == 0, "TODO: Add reordering for remainder"); // is evenly divisible by unroll factor
            for (size_t j = n - 1; j < n; j -= unrollFactor) // for each element `j` in reverse order. when `j` wraps around `j` < `n` is no longer true
            {
                static foreach (k; 0 .. unrollFactor) // inlined (unrolled) loop
                {
                    immutable i = (input[j - k].bijectToUnsigned(descending) >> digitBitshift) & mask; // digit (index)
                    tempSlice[--highOffsets[i]] = input[j - k]; // reorder into tempSlice
                }
            }
            assert(highOffsets[0] == 0); // should equal low offset of first bin

            static if (digitCount & 1) // odd digit count: copy every pass result back
            {
                static if (__traits(compiles, { input[] = tempSlice[]; }))
                    input[] = tempSlice[]; // faster than std.algorithm.copy() because input never overlaps tempSlice
                else
                {
                    import std.algorithm.mutation : copy;
                    copy(tempSlice[], input[]); // TODO: use memcpy
                }
            }
            else // even digit count: ping-pong between the two buffers
            {
                swap(input, tempSlice);
                inputViewsTemp = !inputViewsTemp;
            }
        }

        static if (!(digitCount & 1))
        {
            if (inputViewsTemp) // an odd number of passes ran: the result lives in `tempStorage`
            {
                tempSlice[] = input[];  // `tempSlice` currently views the caller's memory
                swap(input, tempSlice); // make `input` view the caller's memory again
            }
        }
    }

    return input.assumeSorted!sortOrder;
}

/// Digit discardal that skips an odd number of passes still sorts the caller's buffer.
@safe unittest
{
    // The upper 16-bit digit is identical for all these `int`s, so digit
    // discardal skips one of the two passes.
    int[] a = [3, 1, 2, 65_535, 0, 42];
    a.radixSort!(int[], "a", false, true)();
    assert(a == [0, 1, 2, 3, 42, 65_535]);

    int[] d = [3, 1, 2, 65_535, 0, 42];
    d.radixSort!(int[], "a", true, true)();
    assert(d == [65_535, 42, 3, 2, 1, 0]);

    // Empty and single-element inputs are accepted and left unchanged.
    int[] e;
    e.radixSort();
    assert(e.length == 0);
    int[] s = [7];
    s.radixSort();
    assert(s == [7]);
}
229 
// Enables the `version(benchmark)` unittest below unconditionally. It sorts
// 1_00_000 random elements for each of ten element types in every unittest build.
version = benchmark;
// Uncomment to print sample output and speed-ups relative to the stable sort:
// version = show;
232 
version(benchmark)
@safe unittest
{
    version(show) import std.stdio : write, writef, writeln;

    /** Check `radixSort` against a stable `sort` on `n` random elements of
        type `E`, and time both (timings are printed under `version(show)`). */
    void test(E)(int n) @safe
    {
        version(show) writef("%8-s, %10-s, ", E.stringof, n);

        import std.algorithm.comparison : equal, min;
        import std.algorithm.mutation : SwapStrategy;
        import std.algorithm.sorting : isSorted, sort;
        import std.datetime.stopwatch : StopWatch;
        import std.range : retro;
        import nxt.container.dynamic_array : Array = DynamicArray;
        import nxt.random_ex : randInPlace;

        immutable previewLength = 5; // leading elements printed under `version(show)`
        StopWatch watch;

        // Random input. Signed and unsigned types share the same generator.
        // (Alternative for unsigned types: `randInPlaceWithElementRange(cast(E)0, cast(E)uint.max)`.)
        auto input = Array!E.withLength(n);
        input[].randInPlace();
        version(show) write("original random: ", input[0 .. min(previewLength, $)], ", ");

        // Reference result and timing: stable comparison sort.
        auto expected = input.dup;
        watch.reset;
        watch.start();
        expected[].sort!("a < b", SwapStrategy.stable)();
        watch.stop;
        immutable referenceUsecs = watch.peek.total!"usecs";
        version(show) write("quick sorted: ", expected[0 .. min(previewLength, $)], ", ");
        assert(expected[].isSorted);

        {   // descending radix sort must equal the reference reversed
            auto descendingCopy = input.dup;
            descendingCopy[].radixSort!(typeof(descendingCopy[]), "a", true)();
            version(show) write("reverse radix sorted: ", descendingCopy[0 .. min(previewLength, $)], ", ");
            assert(descendingCopy[].retro.equal(expected[]));
        }

        {   // ascending radix sort without digit discardal
            auto ascendingCopy = input.dup;
            watch.reset;
            watch.start();
            ascendingCopy[].radixSort!(typeof(ascendingCopy[]), "b", false)();
            watch.stop;
            immutable radixUsecs = watch.peek.total!"usecs";
            version(show) writef("%9-s, ", cast(real)referenceUsecs / radixUsecs);
            assert(ascendingCopy[].equal(expected[]));
        }

        {   // ascending radix sort with digit discardal
            auto discardalCopy = input.dup;
            watch.reset;
            watch.start();
            discardalCopy[].radixSort!(typeof(discardalCopy[]), "b", false, true)();
            watch.stop;
            immutable discardalUsecs = watch.peek.total!"usecs";
            assert(discardalCopy[].equal(expected[]));
            version(show)
            {
                writeln("standard radix sorted with fast-discardal: ",
                        discardalCopy[0 .. min(previewLength, $)]);
                writef("%9-s, ", cast(real)referenceUsecs / discardalUsecs);
            }
        }

        // TODO: benchmark the in-place variant (`radixSort!(typeof(b[]), "b", false, false, true)`)
        // once it is implemented, e.g. for `E == uint`.

        version(show) writeln("");
    }

    import std.meta : AliasSeq;
    import std.traits : Unsigned;

    immutable elementCount = 100_000;
    version(show) writeln("EType, eCount, radixSort (speed-up), radixSort with fast discardal (speed-up), in-place radixSort (speed-up)");
    static foreach (T; AliasSeq!(byte, short, int, long))
    {
        test!T(elementCount);            // signed type
        test!(Unsigned!T)(elementCount); // its unsigned counterpart
    }
    test!float(elementCount);
    test!double(elementCount);
}
357 
/// `radixSort` agrees with `sort` for all supported numeric element types.
@safe unittest
{
    import std.meta : AliasSeq;
    import std.algorithm.comparison : equal;
    import std.algorithm.sorting : sort;
    import nxt.container.dynamic_array : Array = DynamicArray;
    import nxt.random_ex : randInPlace;

    immutable n = 10_000;

    foreach (E; AliasSeq!(byte, ubyte,
                          short, ushort,
                          int, uint,
                          long, ulong,
                          float, double))
    {
        auto a = Array!E.withLength(n);
        a[].randInPlace();

        auto expected = a.dup;
        auto withDiscardal = a.dup;
        expected[].sort();

        // Both the returned range and the caller's buffer must hold the
        // sorted elements; compare element-wise rather than via struct equality.
        assert(a[].radixSort().equal(expected[]));
        assert(a[].equal(expected[]));

        // Same result when passes for constant digits may be skipped.
        withDiscardal[].radixSort!(typeof(withDiscardal[]), "a", false, true)();
        assert(withDiscardal[].equal(expected[]));
    }
}