1 /**
2   Provide a 2^N-bit integer type.
  Guaranteed to never allocate and to have the expected binary layout.
4   Recursive implementation with very slow division.
5 
6   Copied from https://raw.githubusercontent.com/d-gamedev-team/gfm/master/integers/gfm/integers/wideint.d
7 
8   <b>Supports all operations that builtin integers support.</b>
9 
10   TODO: Integrate representations and potential assembly optimizations from
11   https://github.com/ckormanyos/wide-integer
12 
13   See_Also: https://github.com/ckormanyos/wide-integer
14 
  Bugs: It is unclear whether the unsigned operand takes precedence in a mixed-sign comparison or division:
16           - a < b should be an unsigned comparison if at least one operand is unsigned
17           - a / b should be an unsigned division   if at least one operand is unsigned
18  */
19 module nxt.wideint;
20 
21 // version = format;				// Support std.format
22 
23 import std.traits, std.ascii;
24 
/// Signed integer of arbitrary static precision `bits`.
///
/// Widths of 8, 16, 32 and 64 bits resolve to the builtin `byte`, `short`,
/// `int` and `long`; other widths resolve to `IntImpl`, which requires at
/// least 128 bits.
/// Params:
///    bits = number of bits, must be a power of 2.
alias SInt(uint bits) = Int!(true, bits);
29 
/// Unsigned integer of arbitrary static precision `bits`.
///
/// Widths of 8, 16, 32 and 64 bits resolve to the builtin `ubyte`, `ushort`,
/// `uint` and `ulong`; other widths resolve to `IntImpl`, which requires at
/// least 128 bits.
/// Params:
///    bits = number of bits, must be a power of 2.
alias UInt(uint bits) = Int!(false, bits);
34 
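// Predefined aliases (`int128`, `uint128`, `int256`, `uint256`) are declared at the end of this
// module for unit tests only; any power of 2 of at least 128 bits would work.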
36 
/// Use this template to get an arbitrary sized integer type.
///
/// Widths of 64, 32, 16 and 8 bits are forwarded to the builtin integer
/// types; every other power of two is forwarded to `IntImpl`.
///
/// The constraint is the exact complement of the one on the fallback overload
/// below, so `bits == 0` no longer matches both overloads at once (the former
/// `(bits & (bits - 1)) == 0` test is also true for zero).
/// Params:
///    signed = `true` for a signed type, `false` for an unsigned one.
///    bits = number of bits, must be a power of 2.
private template Int(bool signed, uint bits)
if (isPowerOf2(bits))
{
    // forward to native type for lower numbers of bits in order of most probable
    static if (bits == 64)
    {
        static if (signed)
            alias Int = long;
        else
            alias Int = ulong;
    }
    else static if (bits == 32)
    {
        static if (signed)
            alias Int = int;
        else
            alias Int = uint;
    }
    else static if (bits == 16)
    {
        static if (signed)
            alias Int = short;
        else
            alias Int = ushort;
    }
    else static if (bits == 8)
    {
        static if (signed)
            alias Int = byte;
        else
            alias Int = ubyte;
    }
    else
    {
        // IntImpl itself statically asserts that `bits >= 128`.
        alias Int = IntImpl!(signed, bits);
    }
}
75 
/// Fallback overload: selected for every `bits` that is not a power of two.
/// Instantiating it always fails with a descriptive compile-time error.
private template Int(bool signed, uint bits)
if (isPowerOf2(bits) == false)
{
    static assert(false, "Integer bits " ~ bits.stringof ~ " is not a power of two.");
}
81 
/// Returns: `true` if `x` is a power of two. Zero is not a power of two.
private bool isPowerOf2(in uint x) pure @safe nothrow @nogc
{
	// A power of two has exactly one bit set, so clearing its lowest set
	// bit (`x & (x - 1)`) leaves nothing behind.
	return x != 0 && (x & (x - 1)) == 0;
}
87 
version(unittest)
{
	// Compile-time sanity checks for `isPowerOf2`.
	static foreach (powerOfTwo; [2u, 4u, 8u])
		static assert(isPowerOf2(powerOfTwo));
	static foreach (other; [3u, 5u, 7u])
		static assert(!isPowerOf2(other));
}
97 
98 /// Recursive 2^n integer implementation.
99 private struct IntImpl(bool signed, uint bits)
100 {
101     static assert(bits >= 128);
    private
    {
        // The concrete instantiation being defined.
        alias Self = typeof(this);
		// True if `T`, ignoring type qualifiers, is this very instantiation.
		enum bool isSelf(T) = is(Unqual!T == typeof(this));

        // Half-width types; builtin when bits/2 <= 64, another `IntImpl` otherwise.
        alias sub_int_t = Int!(true, bits/2);   // signed bits/2 integer
        alias sub_uint_t = Int!(false, bits/2); // unsigned bits/2 integer

        // Quarter-width types, used by `toParts` and multiplication.
        alias sub_sub_int_t = Int!(true, bits/4);   // signed bits/4 integer
        alias sub_sub_uint_t = Int!(false, bits/4); // unsigned bits/4 integer

        static if (signed)
            alias hi_t = sub_int_t; // hi_t has same signedness as the whole struct
        else
            alias hi_t = sub_uint_t;

        alias low_t = sub_uint_t;   // low_t is always unsigned

        // Template arguments re-exported so other instantiations can query
        // them (see the `T._bits` / `T._signed` uses in `opAssign` and `opCast`).
        enum _bits = bits, _signed = signed;
    }
122 
123     /// Construct from a value.
124     this(T)(T x) pure nothrow @nogc
125     {
126         opAssign!T(x);
127     }
128 
129     // Private functions used by the `literal` template.
130     private static bool isValidDigitString(string digits)
131     {
132         import std.algorithm.searching : startsWith;
133         import std.ascii : isDigit;
134 
135         if (digits.startsWith("0x"))
136         {
137             foreach (const d; digits[2 .. $])
138                 if (!isHexDigit(d) && d != '_')
139                     return false;
140         }
141         else // decimal
142         {
143             static if (signed)
144                 if (digits.startsWith("-"))
145                     digits = digits[1 .. $];
146             if (digits.length < 1)
147                 return false;   // at least 1 digit required
148             foreach (const d; digits)
149                 if (!isDigit(d) && d != '_')
150                     return false;
151         }
152         return true;
153     }
154 
155     private static typeof(this) literalImpl(string digits)
156     {
157         import std.algorithm.searching : startsWith;
158         import std.ascii : isDigit;
159 
160         typeof(this) value = 0;
161         if (digits.startsWith("0x"))
162         {
163             foreach (const d; digits[2 .. $])
164             {
165                 if (d == '_')
166                     continue;
167                 value <<= 4;
168                 if (isDigit(d))
169                     value += d - '0';
170                 else
171                     value += 10 + toUpper(d) - 'A';
172             }
173         }
174         else
175         {
176             static if (signed)
177             {
178                 bool negative = false;
179                 if (digits.startsWith("-"))
180                 {
181                     negative = true;
182                     digits = digits[1 .. $];
183                 }
184             }
185             foreach (const d; digits)
186             {
187                 if (d == '_')
188                     continue;
189                 value *= 10;
190                 value += d - '0';
191             }
192             static if (signed)
193                 if (negative)
194                     value = -value;
195         }
196         return value;
197     }
198 
    /// Construct from compile-time digit string.
    ///
    /// Both decimal and hex digit strings are supported; `_` may be used as a
    /// digit separator. Invalid strings are rejected with a `static assert`.
    ///
    /// Example:
    /// ----
    /// alias int128 = SInt!128;
    /// auto x = int128.literal!"20_000_000_000_000_000_001";
    /// assert((x >>> 1) == 0x8AC7_2304_89E8_0000);
    ///
    /// auto y = int128.literal!"0x1_158E_4609_13D0_0001";
    /// assert(y == x);
    /// ----
    template literal(string digits)
    {
        static assert(isValidDigitString(digits),
                      "invalid digits in literal: " ~ digits);
        enum literal = literalImpl(digits);
    }
217 
218     /// Assign with a smaller unsigned type.
219     ref typeof(this) opAssign(T)(T n) pure nothrow @nogc if (isIntegral!T && isUnsigned!T)
220     {
221         hi = 0;
222         lo = n;
223         return this;
224     }
225 
226     /// Assign with a smaller signed type (sign is extended).
227     ref typeof(this) opAssign(T)(T n) pure nothrow @nogc if (isIntegral!T && isSigned!T)
228     {
229         // shorter int always gets sign-extended,
230         // regardless of the larger int being signed or not
231         hi = (n < 0) ? cast(hi_t)(-1) : cast(hi_t)0;
232 
233         // will also sign extend as well if needed
234         lo = cast(sub_int_t)n;
235         return this;
236     }
237 
238     /// Assign with a wide integer of the same size (sign is lost).
239     ref typeof(this) opAssign(T)(T n) pure nothrow @nogc if (isWideIntInstantiation!T && T._bits == bits)
240     {
241         hi = n.hi;
242         lo = n.lo;
243         return this;
244     }
245 
246     /// Assign with a smaller wide integer (sign is extended accordingly).
247     ref typeof(this) opAssign(T)(T n) pure nothrow @nogc if (isWideIntInstantiation!T && T._bits < bits)
248     {
249         static if (T._signed)
250         {
251             // shorter int always gets sign-extended,
252             // regardless of the larger int being signed or not
253             hi = cast(hi_t)((n < 0) ? -1 : 0);
254 
255             // will also sign extend as well if needed
256             lo = cast(sub_int_t)n;
257             return this;
258         }
259         else
260         {
261             hi = 0;
262             lo = n;
263             return this;
264         }
265     }
266 
267     /// Cast to a smaller integer type (truncation).
268     T opCast(T)() pure const nothrow @nogc if (isIntegral!T) => cast(T)lo;
269 
270     /// Cast to bool.
271     T opCast(T)() pure const nothrow @nogc if (is(T == bool)) => this != 0;
272 
273     /// Cast to wide integer of any size.
274     T opCast(T)() pure const nothrow @nogc if (isWideIntInstantiation!T)
275     {
276         static if (T._bits < bits)
277             return cast(T)lo;
278         else
279             return T(this);
280     }
281 
282 	version(format)
283 	{
284 		import std.format : FormatSpec;
285 
286 		/// Converts to a string. Supports format specifiers %d, %s (both decimal)
287 		/// and %x (hex).
288 		void toString(Sink)(Sink sink, in FormatSpec!char fmt) const // TODO: something like is(Sink == scope void delegate(scope const(char)[]) @safe))
289 		{
290 			if (fmt.spec == 'x')
291 				toStringHexadecimal(sink);
292 			else
293 				toStringDecimal(sink);
294 		}
295 	}
296 
297     void toStringHexadecimal(Sink)(Sink sink) const // TODO: something like is(Sink == scope void delegate(scope const(char)[]) @safe))
298     {
299 		if (this == 0)
300 			return sink("0");
301 		enum maxDigits = bits / 4;
302 		char[maxDigits] buf;
303 		IntImpl tmp = this;
304 		size_t i;
305 		for (i = maxDigits-1; tmp != 0 && i < buf.length; i--)
306 		{
307 			buf[i] = hexDigits[cast(int)tmp & 0b00001111];
308 			tmp >>= 4;
309 		}
310 		assert(i+1 < buf.length);
311 		sink(buf[i+1 .. $]);
312 	}
313 
314     void toStringDecimal(Sink)(Sink sink) const // TODO: something like is(Sink == scope void delegate(scope const(char)[]) @safe))
315 	{
316 		if (this == 0)
317 			return sink("0");
318 
319 		// The maximum number of decimal digits is basically
320 		// ceil(log_10(2^^bits - 1)), which is slightly below
321 		// ceil(bits * log(2)/log(10)). The value 0.30103 is a slight
322 		// overestimate of log(2)/log(10), to be sure we never
323 		// underestimate. We add 1 to account for rounding up.
324 		enum maxDigits = cast(ulong)(0.30103 * bits) + 1;
325 		char[maxDigits] buf;
326 		size_t i;
327 		Self q = void, r = void;
328 
329 		IntImpl tmp = this;
330 		if (tmp < 0)
331 		{
332 			sink("-");
333 			tmp = -tmp;
334 		}
335 		for (i = maxDigits-1; tmp > 0; i--)
336 		{
337 			assert(i < buf.length);
338 			static if (signed)
339 				Internals!bits.signedDivide(tmp, Self.literal!"10", q, r);
340 			else
341 				Internals!bits.unsignedDivide(tmp, Self.literal!"10", q, r);
342 
343 			buf[i] = digits[cast(int)(r)];
344 			tmp = q;
345 		}
346 		assert(i+1 < buf.length);
347 		sink(buf[i+1 .. $]);
348 	}
349 
350 	typeof(this) opBinary(string op, T)(T o) pure const nothrow @nogc
351     {
352 		typeof(return) r = this;
353         typeof(return) y = o;
354         return r.opOpAssign!(op)(y);
355     }
356 
357     ref typeof(this) opOpAssign(string op, T)(T y) pure nothrow @nogc if (!isSelf!T)
358     {
359         const(Self) o = y;
360         return opOpAssign!(op)(o); // TODO: this can be optimized
361     }
362 
363     ref typeof(this) opOpAssign(string op, T)(T y) pure nothrow @nogc if (isSelf!T)
364     {
365         static if (op == "+")
366         {
367             hi += y.hi;
368             if (lo + y.lo < lo) // deal with overflow
369                 ++hi;
370             lo += y.lo;
371         }
372         else static if (op == "-")
373         {
374             opOpAssign!"+"(-y);
375         }
376         else static if (op == "<<")
377         {
378             if (y >= bits)
379             {
380                 hi = 0;
381                 lo = 0;
382             }
383             else if (y >= bits / 2)
384             {
385                 hi = lo << (y.lo - bits / 2);
386                 lo = 0;
387             }
388             else if (y > 0)
389             {
390                 hi = (lo >>> (-y.lo + bits / 2)) | (hi << y.lo);
391                 lo = lo << y.lo;
392             }
393         }
394         else static if (op == ">>" || op == ">>>")
395         {
396             assert(y >= 0);
397             static if (!signed || op == ">>>")
398                 immutable(sub_int_t) signFill = 0;
399             else
400                 immutable(sub_int_t) signFill = cast(sub_int_t)(isNegative() ? -1 : 0);
401 
402             if (y >= bits)
403             {
404                 hi = signFill;
405                 lo = signFill;
406             }
407             else if (y >= bits/2)
408             {
409                 lo = hi >> (y.lo - bits/2);
410                 hi = signFill;
411             }
412             else if (y > 0)
413             {
414                 lo = (hi << (-y.lo + bits/2)) | (lo >> y.lo);
415                 hi = hi >> y.lo;
416             }
417         }
418         else static if (op == "*")
419         {
420             const sub_sub_uint_t[4] a = toParts();
421             const sub_sub_uint_t[4] b = y.toParts();
422 
423             this = 0;
424             foreach (const uint i; 0 .. 4)
425                 foreach (const uint j; 0 .. (4 - i))
426                     this += Self(cast(sub_uint_t)(a[i]) * b[j]) << ((bits/4) * (i + j));
427         }
428         else static if (op == "&")
429         {
430             hi &= y.hi;
431             lo &= y.lo;
432         }
433         else static if (op == "|")
434         {
435             hi |= y.hi;
436             lo |= y.lo;
437         }
438         else static if (op == "^")
439         {
440             hi ^= y.hi;
441             lo ^= y.lo;
442         }
443         else static if (op == "/" || op == "%")
444         {
445             Self q = void, r = void;
446             static if (signed)
447                 Internals!bits.signedDivide(this, y, q, r);
448             else
449                 Internals!bits.unsignedDivide(this, y, q, r);
450             static if (op == "/")
451                 this = q;
452             else
453                 this = r;
454         }
455         else
456         {
457             static assert(false, "unsupported operation '" ~ op ~ "'");
458         }
459         return this;
460     }
461 
462     // const unary operations
463     Self opUnary(string op)() pure const nothrow @nogc if (op == "+" || op == "-" || op == "~")
464     {
465         static if (op == "-")
466         {
467             Self r = this;
468             r.not();
469             r.increment();
470             return r;
471         }
472         else static if (op == "+")
473            return this;
474         else static if (op == "~")
475         {
476             Self r = this;
477             r.not();
478             return r;
479         }
480     }
481 
482     // non-const unary operations
483     Self opUnary(string op)() pure nothrow @nogc if (op == "++" || op == "--")
484     {
485         static if (op == "++")
486             increment();
487         else static if (op == "--")
488             decrement();
489         return this;
490     }
491 
492     bool opEquals(T)(in T y) pure const @nogc if (!isSelf!T) => this == Self(y);
493     bool opEquals(T)(in T y) pure const @nogc if (isSelf!T) => lo == y.lo && y.hi == hi;
494 
495     int opCmp(T)(in T y) pure const @nogc if (!isSelf!T) => opCmp(Self(y));
496     int opCmp(T)(in T y) pure const @nogc if (isSelf!T)
497     {
498         if (hi < y.hi) return -1;
499         if (hi > y.hi) return 1;
500         if (lo < y.lo) return -1;
501         if (lo > y.lo) return 1;
502         return 0;
503     }
504 
    // binary layout should be what is expected on this platform:
    // the low half is declared first on little-endian targets and the
    // high half is declared first otherwise.
    version (LittleEndian)
    {
        low_t lo; // least significant half, always unsigned
        hi_t hi;  // most significant half, carries the signedness
    }
    else
    {
        hi_t hi;  // most significant half, carries the signedness
        low_t lo; // least significant half, always unsigned
    }
516 
517     private
518     {
519         static if (signed)
520             bool isNegative() @safe pure nothrow const @nogc => signBit();
521         else
522             bool isNegative() @safe pure nothrow const @nogc => false;
523 
524         void not() @safe pure nothrow @nogc
525         {
526             hi = ~hi;
527             lo = ~lo;
528         }
529 
530         void increment() @safe pure nothrow @nogc
531         {
532             ++lo;
533             if (lo == 0) ++hi;
534         }
535 
536         void decrement() @safe pure nothrow @nogc
537         {
538             if (lo == 0) --hi;
539             --lo;
540         }
541 
542 		enum SIGN_SHIFT = bits / 2 - 1;
543 
544         bool signBit() @safe pure const nothrow @nogc => ((hi >> SIGN_SHIFT) & 1) != 0;
545 
546         sub_sub_uint_t[4] toParts() @safe pure const nothrow @nogc
547         {
548             sub_sub_uint_t[4] p = void;
549             enum SHIFT = bits / 4;
550             immutable lomask = cast(sub_uint_t)(cast(sub_sub_int_t)(-1));
551             p[3] = cast(sub_sub_uint_t)(hi >> SHIFT);
552             p[2] = cast(sub_sub_uint_t)(hi & lomask);
553             p[1] = cast(sub_sub_uint_t)(lo >> SHIFT);
554             p[0] = cast(sub_sub_uint_t)(lo & lomask);
555             return p;
556         }
557     }
558 }
559 
/// Evaluates to `true` if a value of type `U` can be passed to a function
/// taking some `IntImpl!(signed, bits)`, i.e. if `U` is a wide integer
/// produced by this module's `IntImpl` template.
template isWideIntInstantiation(U)
{
    // Never called at run time; only used to test whether the template
    // arguments can be deduced from `U`.
    private static void isWideInt(bool signed, uint bits)(IntImpl!(signed, bits) x)
    {
    }

    enum bool isWideIntInstantiation = is(typeof(isWideInt(U.init)));
}
568 
/// Returns: `x` itself if it is not negative, otherwise its negation.
public IntImpl!(signed, bits) abs(bool signed, uint bits)(IntImpl!(signed, bits) x) pure nothrow @nogc
{
	if (x >= 0)
		return x;
	return -x;
}
571 
/// Division routines shared by the signed and unsigned `bits`-wide integers.
private struct Internals(uint bits)
{
    alias wint_t = IntImpl!(true, bits);
    alias uwint_t = IntImpl!(false, bits);

    /// Unsigned division by repeated subtraction: each round subtracts the
    /// largest `divisor << shift` that still fits into what is left of the
    /// dividend and adds `1 << shift` to the quotient.
    ///
    /// `divisor` must not be zero (checked with an `assert`).
    static void unsignedDivide(uwint_t dividend, uwint_t divisor, out uwint_t quotient, out uwint_t remainder) pure nothrow @nogc
    {
        assert(divisor != 0);

        // The operands are by-value copies, so the results can be
        // accumulated directly in the `out` parameters.
        quotient = 0;
        remainder = dividend;

        while (divisor <= remainder)
        {
            // find shift so that (divisor << shift) <= remainder && remainder < (divisor << (shift + 1))
            uint shift = 0;
            uwint_t scaled = divisor;
            while (remainder > scaled && !scaled.signBit())
            {
                uwint_t doubled = scaled << 1;
                if (remainder < doubled)
                    break;

                scaled = doubled;
                ++shift;
            }
            remainder -= scaled;
            quotient += uwint_t(1) << shift;
        }
    }

    /// Signed division built on `unsignedDivide`: the magnitudes are divided
    /// and the signs are fixed up afterwards. The quotient is negative when
    /// the operands have opposite signs; the remainder takes the sign of the
    /// dividend.
    static void signedDivide(wint_t dividend, wint_t divisor, out wint_t quotient, out wint_t remainder) pure nothrow @nogc
    {
        const bool negativeDividend = dividend < 0;
        const bool negativeDivisor = divisor < 0;

        uwint_t q, r;
        unsignedDivide(uwint_t(abs(dividend)), uwint_t(abs(divisor)), q, r);

        // remainder has same sign as the dividend
        if (negativeDividend)
            r = -r;

        // negate the quotient if opposite signs
        if (negativeDividend != negativeDivisor)
            q = -q;

        quotient = q;
        remainder = r;

        assert(remainder == 0 || ((remainder < 0) == (dividend < 0)));
    }
}
628 
// Verify that `toStringDecimal` is callable from `@safe` code as long as the
// sink is `@safe` too, and that it produces the expected digits.
@safe unittest
{
    int256 value = 123;
    value.toStringDecimal((scope const(char)[]) @safe {});
    value.toStringDecimal((scope const(char)[] text) @safe { assert(text == "123"); });
}
637 
version(format)
unittest
{
    import std.format : format;

    int128 x;

    // 2^64 + 0x158E_4609_13D0_0001 == 20_000_000_000_000_000_001
    x.lo = 0x158E_4609_13D0_0001;
    x.hi = 1;
    assert(format("%s", x) == "20000000000000000001");
    assert(format("%d", x) == "20000000000000000001");
    assert(format("%x", x) == "1158E460913D00001");

    // the two's complement of the value above
    x.lo = 0xEA71_B9F6_EC2F_FFFF;
    x.hi = 0xFFFF_FFFF_FFFF_FFFE;
    assert(format("%d", x) == "-20000000000000000001");
    assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");

    // zero
    x.lo = 0;
    x.hi = 0;
    assert(format("%d", x) == "0");

    // all bits set
    x.lo = 0xFFFF_FFFF_FFFF_FFFF;
    x.hi = 0xFFFF_FFFF_FFFF_FFFF;
    assert(format("%d", x) == "-1"); // array index boundary condition
}
661 
// Cross-checks the 128-bit operators against the builtin 64-bit ones: for a
// grid of 64-bit operands, the low 64 bits of every wide result must equal
// the builtin result.
unittest
{
	// Each helper returns source code for one or more assertions. The code
	// must be mixed in; merely calling a helper discards the string and
	// checks nothing (which is what happened to the mixed/unsigned division
	// checks before).
	string testSigned(string op) @safe pure nothrow
	{
		return "assert(cast(ulong)(si" ~ op ~ "sj) == cast(ulong)(csi" ~ op ~ "csj));";
	}
	string testMixed(string op) @safe pure nothrow
	{
		return "assert(cast(ulong)(ui" ~ op ~ "sj) == cast(ulong)(cui" ~ op ~ "csj));"
		~ "assert(cast(ulong)(si" ~ op ~ "uj) == cast(ulong)(csi" ~ op ~ "cuj));";
	}
	string testUnsigned(string op) @safe pure nothrow
	{
		return "assert(cast(ulong)(ui" ~ op ~ "uj) == cast(ulong)(cui" ~ op ~ "cuj));";
	}
	string testAll(string op) @safe pure nothrow
	{
		return testSigned(op) ~ testMixed(op) ~ testUnsigned(op);
	}
    const long step = 164703072086692425;
    for (long si = long.min; si <= long.max - step; si += step)
    {
        for (long sj = long.min; sj <= long.max - step; sj += step)
        {
            const ulong ui = cast(ulong)si;
            const ulong uj = cast(ulong)sj;
            int128 csi = si;
            const uint128 cui = si;
            const int128 csj = sj;
            const uint128 cuj = sj;
            assert(csi == csi);
            assert(~~csi == csi);
            assert(-(-csi) == csi);
            assert(++csi == si + 1);
            assert(--csi == si);

            mixin(testAll("+"));
            mixin(testAll("-"));
            mixin(testAll("*"));
            mixin(testAll("|"));
            mixin(testAll("&"));
            mixin(testAll("^"));
            if (sj != 0)
            {
                mixin(testSigned("/"));
                mixin(testSigned("%"));
                if (si >= 0 && sj >= 0)
                {
                    // those operations are not supposed to be the same at
                    // higher bitdepth: a sign-extended negative may yield higher dividend
                    mixin(testMixed("/"));
                    mixin(testUnsigned("/"));
                    mixin(testMixed("%"));
                    mixin(testUnsigned("%"));
                }
            }
        }
    }
}
721 
// Decimal and hexadecimal literals.
unittest
{
    // Just a little over 2^64, so it actually needs int128.
    // Hex value should be 0x1_158E_4609_13D0_0001.
    enum x = int128.literal!"20_000_000_000_000_000_001";
    assert(x.lo == 0x158E_4609_13D0_0001 && x.hi == 0x1);
    assert((x >>> 1) == 0x8AC7_2304_89E8_0000);

    // the same value written in hex, in upper and in lower case
    enum y = int128.literal!"0x1_158E_4609_13D0_0001";
    enum z = int128.literal!"0x1_158e_4609_13d0_0001"; // case insensitivity
    assert(x == y);
    assert(y == z);
    assert(x == z);
}
734 
// Malformed and negative literals, plus hexadecimal formatting of 0 and -1.
unittest
{
    version(format) import std.format : format;

    // Malformed literals that should be rejected
    assert(!__traits(compiles, int128.literal!""));
    assert(!__traits(compiles, int128.literal!"-"));

    // Negative literals should not be supported for unsigned types
    assert(!__traits(compiles, uint128.literal!"-1"));

    // Negative literals should be supported
    auto x = int128.literal!"-20000000000000000001";
    assert(x.lo == 0xEA71_B9F6_EC2F_FFFF &&
           x.hi == 0xFFFF_FFFF_FFFF_FFFE);
    version(format)
    {
        assert(format("%d", x) == "-20000000000000000001");
        assert(format("%x", x) == "FFFFFFFFFFFFFFFEEA71B9F6EC2FFFFF");
    }

    // Hex formatting tests
    x = 0;
    version(format) assert(format("%x", x) == "0");
    x = -1;
    version(format) assert(format("%x", x) == "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");
}
759 
version(unittest)
{
	// Convenience aliases used by the unit tests in this module; they are
	// only declared in unittest builds.
	alias int128 = SInt!128;		// cent
	alias uint128 = UInt!128;		// ucent
	alias int256 = SInt!256;
	alias uint256 = UInt!256;
}