1 #!/usr/bin/env rdmd-dev
2 
3 /** Integer Sorting Algorithms.
4     Copyright: Per Nordlöw 2018-.
5     License: $(WEB boost.org/LICENSE_1_0.txt, Boost License 1.0).
6     Authors: $(WEB Per Nordlöw)
7  */
8 module nxt.integer_sorting;
9 
10 import std.range.primitives : isRandomAccessRange, ElementType;
11 import std.traits : isNumeric;
12 import std.meta : AliasSeq;
13 
14 import nxt.bijections;
15 
/** Radix sort of `input`.

    Elements are ordered by the unsigned key returned by `bijectToUnsigned`
    (presumably provided by `nxt.bijections` - confirm), one digit of
    `radixBitCount` bits at a time.

    The default (non-in-place) variant is a least-significant-digit-first
    counting sort that uses a temporary buffer of `input.length` elements.

    The in-place variant (`inPlace == true`) is unfinished: it only computes
    per-digit bin offsets and never reorders, so its final `assert` only holds
    for input that is already sorted.

    Note that the template constraint currently requires `R` to be a
    `RandomAccessRange` with numeric elements. When the element type has an
    even number of digits the implementation additionally swaps `input` with a
    slice of the temporary buffer, which requires both to have the same type.

    Params:
        R = type of range to sort
        fun = currently unused (see TODO below)
        descending = sort in descending instead of ascending order
        requestDigitDiscardal = skip passes for digits that are equal in all
                                elements (only used when there are at least two
                                digits)
        inPlace = select the (unfinished) in-place variant
        input = range to sort

    Returns: `input` wrapped by `assumeSorted` with predicate `"a < b"`, or
             `"a > b"` when `descending` is set.

    `radixBitCount` is the number of bits in radix (digit)

    TODO make `radixBitCount` a template parameter either 8 or 16,
    ElementType.sizeof must be a multiple of radixBitCount

    TODO input[] = y[] not needed when input is mutable

    TODO Restrict fun.

    TODO Choose fastDigitDiscardal based on elementMin and elementMax (if they
    are given)

    See_Also: https://probablydance.com/2016/12/27/i-wrote-a-faster-sorting-algorithm/
    See_Also: https://github.com/skarupke/ska_sort/blob/master/ska_sort.hpp
    See_Also: http://forum.dlang.org/thread/vmytpazcusauxypkwdbn@forum.dlang.org#post-vmytpazcusauxypkwdbn:40forum.dlang.org
 */
auto radixSort(R,
               alias fun = "a",
               bool descending = false,
               bool requestDigitDiscardal = false,
               bool inPlace = false)(R input,
                                     /* ElementType!R elementMin = ElementType!(R).max, */
                                     /* ElementType!R elementMax = ElementType!(R).min */)

    @trusted
if (isRandomAccessRange!R &&
    (isNumeric!(ElementType!R)))
{
    import std.range : assumeSorted;
    import std.algorithm.sorting : isSorted; // TODO move this to radixSort when know how map less to descending
    import std.range.primitives : front;

    immutable n = input.length; // number of elements
    alias E = ElementType!R;
    enum elementBitCount = 8*E.sizeof; // total number of bits needed to code each element

    /* Lookup number of radix bits from sizeof `ElementType`.
       These give optimal performance on Intel Core i7.
    */
    static if (elementBitCount == 8 ||
               elementBitCount == 24)
    {
        enum radixBitCount = 8;
    }
    else static if (elementBitCount == 16 ||
                    elementBitCount == 32 ||
                    elementBitCount == 64)
    {
        enum radixBitCount = 16;
    }
    else
    {
        static assert(0, "TODO handle element type " ~ E.stringof);
    }

    // TODO activate this: subtract min from all values and then immutable uint elementBitCount = is_min(a_max) ? 8*sizeof(E) : binlog(a_max); and add it back.
    enum digitCount = elementBitCount / radixBitCount;         // number of `digitCount` in radix `radixBitCount`
    static assert(elementBitCount % radixBitCount == 0,
                  "Precision of ElementType must be evenly divisble by bit-precision of Radix.");

    enum doDigitDiscardal = requestDigitDiscardal && digitCount >= 2;

    enum radix = cast(typeof(radixBitCount))1 << radixBitCount;    // bin count
    enum mask = radix-1;                                     // radix bit mask

    alias UE = typeof(input.front.bijectToUnsigned); // get unsigned integer type of same precision as \tparam E.

    // predicate describing the order of the returned range
    enum sortedPredicate = descending ? "a > b" : "a < b";

    // Ranges with fewer than two elements are already sorted. Returning early
    // also avoids reading `input[0]` of an empty range further down.
    if (n < 2)
    {
        return input.assumeSorted!sortedPredicate;
    }

    import nxt.fixed_dynamic_array : FixedDynamicArray;

    static if (inPlace) // most-significant digit (MSD) first in-place radix sort
    {
        static assert(!descending, "TODO Implement descending version");

        foreach (immutable digitOffsetReversed; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with most significant (MSD-first)
        {
            immutable digitOffset = digitCount - 1 - digitOffsetReversed;
            immutable digitBitshift = digitOffset*radixBitCount; // digit bit shift

            // [lowOffsets[i], highOffsets[i]] will become slices into `input`
            size_t[radix] lowOffsets; // low offsets for each bin
            size_t[radix] highOffsets; // high offsets for each bin

            // calculate counts
            foreach (immutable j; 0 .. n) // for each element index `j` in `input`
            {
                immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                ++highOffsets[i];   // increase histogram bin counter
            }

            // bin boundaries: accumulate bin counters array
            lowOffsets[0] = 0;             // first low is always zero
            foreach (immutable j; 1 .. radix) // for each successive bin counter
            {
                lowOffsets[j] = highOffsets[j - 1]; // previous roof becomes current floor
                highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
            }
            assert(highOffsets[radix - 1] == n); // should equal high offset of last bin
        }

        // /** \em unstable in-place (permutate) reorder/sort `input`
        //  * access `input`'s elements in \em reverse to \em reuse filled caches from previous forward iteration.
        //  * \see `in_place_indexed_reorder`
        //  */
        // for (int r = radix - 1; r >= 0; --r) // for each radix digit r in reverse order (cache-friendly)
        // {
        //     while (binStat[r])  // as long as elements left in r:th bucket
        //     {
        //         immutable uint i0 = binStat[r].pop_back(); // index to first element of permutation
        //         immutable E    e0 = input[i0]; // value of first/current element of permutation
        //         while (true)
        //         {
        //             immutable int rN = (e0.bijectToUnsigned(descending) >> digitBitshift) & mask; // next digit (index)
        //             if (r == rN) // if permutation cycle closed (back to same digit)
        //                 break;
        //             immutable ai = binStat[rN].pop_back(); // array index
        //             swap(input[ai], e0); // do swap
        //         }
        //         input[i0] = e0;         // complete cycle
        //     }
        // }

        // TODO copy reorder algorithm into local function that calls itself in the recursion step
        // TODO call this local function

        // NOTE no reordering is performed above yet, so this only holds for already sorted input
        assert(input.isSorted!"a < b");
    }
    else                        // standard radix sort
    {
        // non-in-place requires temporary `y`. TODO we could allocate these as
        // a stack-allocated array for small arrays and gain extra speed.
        auto tempStorage = FixedDynamicArray!E.makeUninitializedOfLength(n);
        auto tempSlice = tempStorage[];

        // `true` while `input` and `tempSlice` are exchanged, that is while
        // `input` refers to the memory of `tempStorage`. Only changes when the
        // number of digits is even (ping-pong between the two buffers).
        bool buffersExchanged = false;

        static if (doDigitDiscardal)
        {
            UE ors  = 0;         // digits diff(xor)-or-sum
        }

        foreach (immutable digitOffset; 0 .. digitCount) // for each `digitOffset` (in base `radix`) starting with least significant (LSD-first)
        {
            immutable digitBitshift = digitOffset*radixBitCount;   // digit bit shift

            static if (doDigitDiscardal)
            {
                if (digitOffset != 0) // if first iteration already performed we have bit statistics
                {
                    if ((! ((ors >> digitBitshift) & mask))) // if no bit of this digit differs between any two adjacent elements
                    {
                        continue;               // no sorting is needed for this digit
                    }
                }
            }

            // calculate counts
            size_t[radix] highOffsets; // histogram buckets count and later upper-limits/walls for values in `input`
            UE previousUnsignedValue = cast(UE)input[0].bijectToUnsigned(descending);
            foreach (immutable j; 0 .. n) // for each element index `j` in `input`
            {
                immutable UE currentUnsignedValue = cast(UE)input[j].bijectToUnsigned(descending);
                static if (doDigitDiscardal)
                {
                    if (digitOffset == 0) // first iteration calculates statistics
                    {
                        ors |= previousUnsignedValue ^ currentUnsignedValue; // accumulate bit change statistics
                        // ors |= currentUnsignedValue; // accumulate bits statistics
                    }
                }
                immutable i = (currentUnsignedValue >> digitBitshift) & mask; // digit (index)
                ++highOffsets[i];              // increase histogram bin counter
                previousUnsignedValue = currentUnsignedValue;
            }

            static if (doDigitDiscardal)
            {
                if (digitOffset == 0) // statistics for all digits have just been gathered
                {
                    if ((! ((ors >> digitBitshift) & mask))) // if no bit of this digit differs between any two adjacent elements
                    {
                        continue;               // no sorting is needed for this digit
                    }
                }
            }

            // bin boundaries: accumulate bin counters array
            foreach (immutable j; 1 .. radix) // for each successive bin counter
            {
                highOffsets[j] += highOffsets[j - 1]; // accumulate bin counter
            }
            assert(highOffsets[radix - 1] == n); // should equal high offset of last bin

            // reorder. access `input`'s elements in \em reverse to \em reuse filled caches from previous forward iteration.
            // \em stable reorder from `input` to `tempSlice` using normal counting sort (see `counting_sort` above).
            enum unrollFactor = 1;
            assert((n % unrollFactor) == 0, "TODO Add reordering for remainder"); // is evenly divisible by unroll factor
            for (size_t j = n - 1; j < n; j -= unrollFactor) // for each element `j` in reverse order. when `j` wraps around `j` < `n` is no longer true
            {
                static foreach (k; 0 .. unrollFactor) // inlined (unrolled) loop
                {
                    immutable i = (input[j - k].bijectToUnsigned(descending) >> digitBitshift) & mask; // digit (index)
                    tempSlice[--highOffsets[i]] = input[j - k]; // reorder into tempSlice
                }
            }
            assert(highOffsets[0] == 0); // should equal low offset of first bin

            static if (digitCount & 1) // if odd number of digit passes
            {
                static if (__traits(compiles, input[] == tempSlice[]))
                {
                    input[] = tempSlice[]; // faster than std.algorithm.copy() because input never overlap tempSlice
                }
                else
                {
                    import std.algorithm.mutation : copy;
                    copy(tempSlice[], input[]); // TODO use memcpy
                }
            }
            else
            {
                import std.algorithm.mutation : swap;
                swap(input, tempSlice);
                buffersExchanged = !buffersExchanged;
            }
        }

        // When digit discardal skipped an odd number of passes the sorted
        // elements live in `tempStorage`, which does not outlive this
        // function. Move them to the caller's buffer and make `input` refer
        // to that buffer again.
        if (buffersExchanged)
        {
            import std.algorithm.mutation : copy, swap;
            copy(input, tempSlice); // here `tempSlice` refers to the caller's buffer
            swap(input, tempSlice);
        }
    }

    return input.assumeSorted!sortedPredicate;
}
259 
version = benchmark;
// version = show;

/* Compares `radixSort` against stable `std.algorithm.sorting.sort` for all
 * supported element types and, when `version = show` is active, prints the
 * measured speed-ups.
 */
version(benchmark)
@safe unittest
{
    version(show) import std.stdio : write, writef, writeln;

    /** Verify (and time) `radixSort` for `elementCount` random elements of type `E`. */
    void verify(E)(int elementCount) @safe
    {
        version(show) writef("%8-s, %10-s, ", E.stringof, elementCount);

        import nxt.dynamic_array : Array = DynamicArray;
        import nxt.random_ex : randInPlace;

        import std.algorithm.comparison : min, equal;
        import std.algorithm.mutation : SwapStrategy;
        import std.algorithm.sorting : sort, isSorted;
        import std.datetime.stopwatch : StopWatch;
        import std.range : retro;

        auto timer = StopWatch();
        immutable shownCount = 5; // number of leading elements printed by `version(show)`

        // Random input. Signed and unsigned element types are filled the same way.
        // Alternative: unsorted[].randInPlaceWithElementRange(cast(E)0, cast(E)uint.max);
        auto unsorted = Array!E.withLength(elementCount);
        unsorted[].randInPlace();
        version(show) write("original random: ", unsorted[0 .. min(shownCount, $)], ", ");

        // reference result: stable sort from Phobos
        auto reference = unsorted.dup;

        timer.reset;
        timer.start();
        reference[].sort!("a < b", SwapStrategy.stable)();
        timer.stop;
        immutable referenceUsecs = timer.peek.total!"usecs";
        version(show) write("quick sorted: ", reference[0 .. min(shownCount, $)], ", ");
        assert(reference[].isSorted);

        // descending radix sort must equal the reversed reference
        {
            auto candidate = unsorted.dup;
            candidate[].radixSort!(typeof(candidate[]), "a", true)();
            version(show) write("reverse radix sorted: ", candidate[0 .. min(shownCount, $)], ", ");
            assert(candidate[].retro.equal(reference[]));
        }

        // ascending radix sort
        {
            auto candidate = unsorted.dup;

            timer.reset;
            timer.start();
            candidate[].radixSort!(typeof(candidate[]), "b", false)();
            timer.stop;
            immutable plainUsecs = timer.peek.total!"usecs";

            version(show) writef("%9-s, ", cast(real)referenceUsecs / plainUsecs);
            assert(candidate[].equal(reference[]));
        }

        // ascending radix sort with digit discardal requested
        {
            auto candidate = unsorted.dup;

            timer.reset;
            timer.start();
            candidate[].radixSort!(typeof(candidate[]), "b", false, true)();
            timer.stop;
            immutable discardalUsecs = timer.peek.total!"usecs";

            assert(candidate[].equal(reference[]));

            version(show)
            {
                writeln("standard radix sorted with fast-discardal: ",
                        candidate[0 .. min(shownCount, $)]);
            }
            version(show) writef("%9-s, ", cast(real)referenceUsecs / discardalUsecs);
        }

        // in-place radix sort (disabled until the in-place variant is implemented)
        // static if (is(E == uint))
        // {
        //     auto b = a.dup;

        //     sw.reset;
        //     sw.start();
        //     b[].radixSort!(typeof(b[]), "b", false, false, true)();
        //     sw.stop;
        //     immutable radixTime = sw.peek.usecs;

        //     assert(b[].equal(qa[]));

        //     version(show)
        //     {
        //         writeln("in-place radix sorted with fast-discardal: ",
        //                 b[0 .. min(nMax, $)]);
        //     }
        //     writef("%9-s, ", cast(real)sortTimeUsecs / radixTime);
        // }

        version(show) writeln("");
    }

    import std.meta : AliasSeq;
    import std.traits : Unsigned;

    immutable elementCount = 100_000;
    version(show) writeln("EType, eCount, radixSort (speed-up), radixSort with fast discardal (speed-up), in-place radixSort (speed-up)");

    // each integral width: first the signed type, then its unsigned counterpart
    foreach (T; AliasSeq!(byte, short, int, long))
    {
        verify!T(elementCount);
        verify!(Unsigned!T)(elementCount);
    }

    // floating point types
    verify!float(elementCount);
    verify!double(elementCount);
}
387 
/* `radixSort` with default template arguments must give the same result as
 * `std.algorithm.sorting.sort` for every supported element type.
 */
@safe unittest
{
    import std.meta : AliasSeq;
    import std.algorithm.mutation : swap;
    import std.algorithm.sorting : sort;
    import nxt.dynamic_array : Array = DynamicArray;
    import nxt.random_ex : randInPlace;

    immutable elementCount = 10_000;

    alias ElementTypes = AliasSeq!(byte, ubyte,
                                   short, ushort,
                                   int, uint,
                                   long, ulong,
                                   float, double);

    foreach (E; ElementTypes)
    {
        // two identical random arrays: one per sorting algorithm
        auto radixInput = Array!E.withLength(elementCount);
        radixInput[].randInPlace();
        auto referenceInput = radixInput.dup;

        auto radixSorted = radixInput[].radixSort();
        auto referenceSorted = referenceInput[].sort();
        assert(radixSorted == referenceSorted);

        swap(radixInput, referenceInput);
    }
}