Skip to content

Commit a0fbc41

Browse files
committed
Add new test: computing \pi
The program pi is capable of computing the n^th decimal digit of pi with constant memory using only 32-bit integer arithmetic. Based on pi1.c by Fabrice Bellard, 1997. https://bellard.org/pi/
1 parent a666741 commit a0fbc41

File tree

2 files changed

+323
-0
lines changed

2 files changed

+323
-0
lines changed

build/pi.elf

91.7 KB
Binary file not shown.

tests/pi.c

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
/*
2+
* Computation of the n^th decimal digit of pi with constant memory using
3+
* only 32-bit integer arithmetic.
4+
*
5+
* This program is optimized by David McWilliams, 2021.
6+
* Based on pi1.c by Fabrice Bellard, 1997.
7+
* https://bellard.org/pi/
8+
*
9+
* Uses the hypergeometric series by Bill Gosper, 1974.
10+
* pi = sum( (50*n-6)/(binomial(3*n,n)*2^n), n=0..infinity )
11+
* https://arxiv.org/abs/math/0110238
12+
*
13+
* Uses the constant memory algorithm by Simon Plouffe, 1996.
14+
* https://arxiv.org/abs/0912.0303
15+
*
16+
* See also the faster n^th decimal digit program by Xavier Gourdon, 2003.
17+
* http://numbers.computation.free.fr/Constants/Algorithms/pidec.cpp
18+
*
19+
* To calculate the millionth digit of pi we need:
20+
* - Modulo multiplication that can handle base 2,654,253 without overflow.
21+
* - 6,505,391,993,984,718 main loops if all previous digits are calculated.
22+
* - 171,247,233,500 main loops if only the millionth digit is calculated.
23+
*/
24+
25+
#include <stdint.h>
26+
#include <stdio.h>
27+
28+
/* Compute (a * b) % m without any intermediate value leaving the signed
 * 32-bit range. The multiplier a is consumed one byte at a time while b is
 * pre-shifted by 8 and 16 bits (each time reduced modulo m).
 *
 * Latency: 4 ticks on input a, 6 ticks on input b.
 *
 * Input range:  0 <= a < 16777216 (2^24, i.e. three bytes)
 * Input range:  0 <= b <= 2796202 (INT32_MAX/256/3)
 * Input range:  0 <  m <= 2796202 (INT32_MAX/256/3)
 * Output range: 0 <= result < m
 */
static inline int32_t mul_mod_21(int32_t a, int32_t b, int32_t m)
{
    /* b * 2^8 and b * 2^16, both reduced modulo m */
    const int32_t b_shl8 = (b << 8) % m;
    const int32_t b_shl16 = (b_shl8 << 8) % m;

    /* Accumulate byte-by-byte: low, middle, high byte of a */
    int32_t acc = (a & 0xFF) * b;
    acc += ((a >> 8) & 0xFF) * b_shl8;
    acc += ((a >> 16) & 0xFF) * b_shl16;

    return acc % m;
}
44+
45+
/* Partially reduced modular multiplication with 3 tick latency on input a,
 * and 7 tick latency on input b.
 *
 * a is reinterpreted as an unsigned 32-bit value and split into four bytes;
 * b is pre-shifted by 8, 16 and 24 bits (each time reduced modulo m). The
 * four partial products are summed WITHOUT a final reduction, so the result
 * is only congruent to a * b modulo m when it is itself read as an unsigned
 * 32-bit value (which is how this function reads its own input a).
 *
 * Input range:  INT32_MIN <= a <= INT32_MAX (treated as uint32_t)
 * Input range:  0 <= b <= 4194304 (2^32/256/4)
 * Input range:  0 <  m <= 4194304 (2^32/256/4)
 * Output range: INT32_MIN <= result <= INT32_MAX (bit pattern of the sum)
 *
 * The sum can reach 4 * 255 * 4194304, which exceeds INT32_MAX, so all
 * arithmetic is carried out in uint32_t: adding the products as int32_t
 * would be signed overflow (undefined behavior). Converting the final
 * uint32_t back to int32_t is implementation-defined and wraps on every
 * two's-complement compiler.
 */
static inline int32_t mul_mod_22(int32_t a, int32_t b, int32_t m)
{
    const uint32_t ua = (uint32_t) a;
    const uint32_t um = (uint32_t) m;

    /* Bytes of a, least significant first */
    const uint32_t a1 = ua & 0xFF;
    const uint32_t a2 = (ua >> 8) & 0xFF;
    const uint32_t a3 = (ua >> 16) & 0xFF;
    const uint32_t a4 = (ua >> 24) & 0xFF;

    /* b * 2^8, b * 2^16, b * 2^24, each reduced modulo m */
    const uint32_t b1 = (uint32_t) b;
    const uint32_t b2 = (b1 << 8) % um;
    const uint32_t b3 = (b2 << 8) % um;
    const uint32_t b4 = (b3 << 8) % um;

    const uint32_t sum = a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4;
    return (int32_t) sum;
}
63+
64+
/* Modular multiplication with 4 tick latency on input a,
 * and 7 tick latency on input b.
 *
 * a is reinterpreted as an unsigned 32-bit value and split into four bytes;
 * b is pre-shifted by 8, 16 and 24 bits (each time reduced modulo m). Each
 * partial product is reduced modulo m before the four are added, so the
 * result is congruent to a * b modulo m but may still be as large as 4*m.
 *
 * Input range:  INT32_MIN <= a <= INT32_MAX (treated as uint32_t)
 * Input range:  0 <= b <= 8421504 (INT32_MAX/255)
 * Input range:  0 <  m <= 8421504 (INT32_MAX/255)
 * Output range: 0 <= result < 4*m
 *
 * The shifts are performed in uint32_t: for b >= 2^23 the value b << 8
 * no longer fits in int32_t, although it always fits in 32 unsigned bits.
 */
static inline int32_t mul_mod_23(int32_t a, int32_t b, int32_t m)
{
    const uint32_t ua = (uint32_t) a;
    const uint32_t um = (uint32_t) m;

    /* Bytes of a, least significant first */
    const int32_t a1 = (int32_t) (ua & 0xFF);
    const int32_t a2 = (int32_t) ((ua >> 8) & 0xFF);
    const int32_t a3 = (int32_t) ((ua >> 16) & 0xFF);
    const int32_t a4 = (int32_t) ((ua >> 24) & 0xFF);

    /* b * 2^8, b * 2^16, b * 2^24, each reduced modulo m (so each is < m) */
    const int32_t b2 = (int32_t) (((uint32_t) b << 8) % um);
    const int32_t b3 = (int32_t) (((uint32_t) b2 << 8) % um);
    const int32_t b4 = (int32_t) (((uint32_t) b3 << 8) % um);

    /* Each product is at most 255 * 8421504 < INT32_MAX */
    return a1 * b % m + a2 * b2 % m + a3 * b3 % m + a4 * b4 % m;
}
82+
83+
/* Return base^exp using binary exponentiation (square-and-multiply).
 *
 * - exp >= 0: the product is computed modulo 2^32 in unsigned arithmetic, so
 *   results that do not fit in int32_t wrap around instead of invoking
 *   signed-overflow undefined behavior (e.g. powi(2, 32) == 0).
 * - exp <  0: the truncated integer value of 1 / base^(-exp) is returned:
 *   1 for base == 1, +/-1 for base == -1, and 0 for every other base
 *   (including base == 0, where the true value is undefined).
 */
int32_t powi(int32_t base, int32_t exp)
{
    if (exp < 0) {
        /* Shifting a negative exponent right never reaches zero, so handle
         * negative exponents up front rather than looping forever.
         */
        if (base == 1)
            return 1;
        if (base == -1)
            return (exp & 1) ? -1 : 1;
        return 0;
    }

    uint32_t result = 1;
    uint32_t factor = (uint32_t) base;
    for (uint32_t bits = (uint32_t) exp; bits != 0; bits >>= 1) {
        if (bits & 1)
            result *= factor;
        factor *= factor;
    }
    return (int32_t) result;
}
95+
96+
/* Return (a^b) mod m by binary exponentiation.
 *
 * Every multiplication goes through mul_mod_21(), so a and m must respect
 * that function's input ranges. A non-positive b leaves the accumulator at
 * its initial value and returns 1.
 */
int32_t pow_mod(int32_t a, int32_t b, int32_t m)
{
    int32_t acc = 1;

    /* Walk the exponent from its least significant bit upwards */
    for (; b > 0; b >>= 1) {
        if ((b & 1) != 0)
            acc = mul_mod_21(acc, a, m);
        a = mul_mod_21(a, a, m); /* square the base for the next bit */
    }

    return acc;
}
108+
109+
/* Solve for x: (a * x) % m == 1
110+
* https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers
111+
*
112+
* N divisions is enough to calculate up to Fibonacci(N+3). Donald Knuth, 1981.
113+
* The Art of Computer Programming, Vol. 2: Seminumerical Algorithms, 2nd ed.
114+
* page 343.
115+
*
116+
* With 2 divisions per loop, 15 loops is enough to calculate up to 3500000.
117+
* Test case: 1346269 * 1346269 % 2178309 == 1
118+
*/
119+
int32_t inv_mod(int32_t a, int32_t m)
120+
{
121+
a %= m;
122+
int32_t b = m;
123+
int32_t x = 1;
124+
int32_t y = 0;
125+
for (int32_t i = 0; i < 15; i++) {
126+
int32_t q = (a == 0) ? 0 : b / a;
127+
b -= a * q;
128+
y -= x * q;
129+
q = (b == 0) ? 0 : a / b;
130+
a -= b * q;
131+
x -= y * q;
132+
}
133+
134+
return b ? (y + m) : x;
135+
}
136+
137+
/* Return the smallest prime strictly greater than n.
 *
 * Any n below 2 (including zero and negative values) yields 2. n must be
 * less than INT32_MAX so that the answer is representable.
 *
 * Primality is decided by trial division up to a cached integer square
 * root. The cache is kept in a function-local static so that consecutive
 * calls with increasing n do not recompute it; it is reset whenever a call
 * starts at or below the cached value. Because of that static, the function
 * is not reentrant.
 */
int next_prime(int32_t n)
{
    static uint32_t square_root = 0;

    /* 2 is the first prime; this also keeps negative inputs out of the
     * unsigned arithmetic below.
     */
    if (n < 2)
        return 2;

    uint32_t candidate = (uint32_t) n + 1;

    if (square_root >= candidate)
        square_root = 0; /* reset cached value */

    for (;; candidate++) {
        /* Grow the cached root until square_root^2 >= candidate - 1, which
         * is enough to cover every divisor up to sqrt(candidate).
         */
        while (square_root * square_root < candidate - 1)
            square_root++;

        int has_factor = 0;
        for (uint32_t i = 2; i <= square_root; i++) {
            if (candidate % i == 0) {
                has_factor = 1;
                break;
            }
        }

        if (!has_factor) /* Found prime number */
            return (int) candidate;
    }
}
164+
165+
/* Table of powers of the prime currently being processed. The caller fills
 * prime_power[0 .. prime_power_count-1]; pifactory() stores prime^i at
 * index i, so index 0 holds 1.
 */
static int32_t prime_power[15];
static int32_t prime_power_count;

/* Strip the largest tabulated prime power that divides *n.
 *
 * The table is scanned from the highest index downwards. On the first entry
 * that divides *n evenly, *n is divided by it and the entry's index (the
 * number of prime factors removed) is returned. Because prime_power[0] is
 * expected to be 1, some entry always matches; reaching the end of the scan
 * is declared unreachable.
 */
int32_t factor_count(int32_t *n)
{
    int32_t idx = prime_power_count;

    while (idx-- > 0) {
        const int32_t divisor = prime_power[idx];
        if (*n % divisor != 0)
            continue;
        *n /= divisor;
        return idx;
    }
    __builtin_unreachable();
}
177+
178+
/* Calculate sum = (sum + n/d) and store the decimal part in fixed-point format
 * with 18 decimal places across two 32-bit integers.
 *
 * *hi holds decimal places 1 to 9 and *lo holds decimal places 10 to 18,
 * each as an integer in [0, 999999999]. Digits that overflow past the
 * decimal point are discarded.
 *
 * This is equivalent to the floating point one-liner:
 *   sum = fmod(sum + (double)n / (double)d, 1.0);
 *
 * Inputs must be <= 10,737,418 or INT32_MAX/200. The digit extraction
 * assumes 0 <= n < d and d > 0.
 */
void fixed_point_sum(int32_t n, int32_t d, int32_t *hi, int32_t *lo)
{
    /* Digits 1 to 9.
     * Long division in four steps whose scale factors multiply to 10^9
     * (200 * 200 * 200 * 125), so no intermediate exceeds 32 bits. The
     * weights below are the remaining scale after each step.
     */
    int32_t n1 = n * 200;
    int32_t n2 = n1 % d * 200;
    int32_t n3 = n2 % d * 200;
    int32_t n4 = n3 % d * 125;
    *hi += n1 / d * 5000000;
    *hi += n2 / d * 25000;
    *hi += n3 / d * 125;
    *hi += n4 / d;

    /* Digits 10 to 18: same scheme, continuing from the last remainder */
    int32_t n5 = n4 % d * 200;
    int32_t n6 = n5 % d * 200;
    int32_t n7 = n6 % d * 200;
    int32_t n8 = n7 % d * 125;
    *lo += n5 / d * 5000000;
    *lo += n6 / d * 25000;
    *lo += n7 / d * 125;
    *lo += n8 / d;

    /* Carry: a low word of exactly 10^9 wraps to 0 below, so it must carry
     * into the high word as well (hence >=, not >).
     */
    if (*lo >= 1000000000)
        *hi += 1;

    /* Discard overflow digits */
    *hi %= 1000000000;
    *lo %= 1000000000;
}
216+
217+
/* Return 9 digits of pi.
 *
 * start_digit is the decimal shift applied to the series: every partial
 * result is multiplied by 10^start_digit before its fractional part is
 * accumulated, and the returned value is the high word of the 18-place
 * fixed-point sum maintained by fixed_point_sum(), i.e. an integer in
 * [0, 999999999]. The low word (sum_low) only feeds carries into the high
 * word and is not returned.
 *
 * For every prime below 3*N the series is evaluated modulo a power m of that
 * prime, and the resulting fraction subtotal/m is added to the sum.
 *
 * Side effects: overwrites the file-level prime_power[] table and
 * prime_power_count, which factor_count() reads.
 */
int32_t pifactory(int32_t start_digit)
{
    int32_t sum = 0, sum_low = 0;

    /* Number of series terms to evaluate:
     * N = (start_digit + 19) / log10(13.5)
     * log10(13.5) is approximately equal to 269/238
     */
    int32_t N = (start_digit + 19) * 238 / 269;

    /* Compute the Gosper series modulo each prime power up to 3*N */
    for (int32_t prime = 2; prime < 3 * N; prime = next_prime(prime)) {
        /* Compute the first few prime powers
         * Only 15 powers are needed if start_digit < 1,000,000
         * Only powers up to 10,000,000 are needed if start_digit <= 1,000,000
         *
         * ROOT_10M[i] appears to be the integer i-th root of 10,000,000
         * (approximately), so prime^i is tabulated only while it stays
         * around or below 10,000,000. The table is non-increasing, so the
         * filled entries always form a prefix of prime_power[].
         */
        static const int32_t ROOT_10M[15] = {
            10000000, 10000000, 3162, 215, 56, 25, 14, 10, 7, 6, 5, 4, 3, 3, 3,
        };
        prime_power_count = 0;
        for (int32_t i = 0; i < 15; i++) {
            if (prime <= ROOT_10M[i]) {
                prime_power[i] = powi(prime, i);
                prime_power_count++;
            }
        }

        /* For small primes, use a prime power with exponent greater than 1.
         * exponent ends up as the largest tabulated index whose power is
         * still below 3*N (it starts at -1 because prime_power[0] == 1
         * always counts).
         */
        int32_t exponent = -1;
        for (int32_t i = 0; i < prime_power_count; i++) {
            if (prime_power[i] < 3 * N)
                exponent++;
        }
        int32_t m = powi(prime, exponent);

        if (prime == 2) {
            /* Add the 2^N term in the denominator. */
            exponent += N - 1;
            /* We have some more powers of 2 in the 10^start_digit decimal shift
             * in the numerator. Use them to cancel out the 2^N term.
             */
            m = powi(prime, exponent - start_digit);
            /* Since start_digit grows faster than N, eventually we will
             * cancel the entire exponent and m will become 0.
             *
             * NOTE(review): this relies on powi() returning 0 when its
             * exponent argument is negative; confirm that powi() terminates
             * and yields 0 for negative exponents.
             */
            if (m == 0)
                continue;
        }

        /* Multiply by 10^start_digit to move the target digit to the most
         * significant decimal place.
         */
        int32_t decimal = 10;
        if (prime == 2) /* We already used those powers of 2 */
            decimal = 5;
        int32_t decimal_shift = pow_mod(decimal, start_digit, m);

        /* Main loop.
         * numerator and denominator are running products with every factor
         * of the current prime stripped out by factor_count(); the net
         * multiplicity of the prime is tracked separately in exponent.
         */
        int32_t subtotal = 0;
        int32_t numerator = 1;
        int32_t denominator = 1;
        for (int32_t k = 1; k <= N; k++) {
            /* Terms for the numerator */
            int32_t t1 = 2 * k, t2 = 2 * k - 1;
            exponent += factor_count(&t1);
            exponent += factor_count(&t2);
            int32_t terms = mul_mod_21(t1 % m, t2 % m, m);
            /* mul_mod_22 does not fully reduce; it is fed back as input a */
            numerator = mul_mod_22(numerator, terms, m);

            /* Terms for the denominator */
            int32_t t3 = 6 * k - 4, t4 = 9 * k - 3;
            exponent -= factor_count(&t3);
            exponent -= factor_count(&t4);
            terms = mul_mod_21(t3 % m, t4 % m, m);
            denominator = mul_mod_22(denominator, terms, m);

            /* Multiply all parts together:
             * (50*k - 6) * numerator * prime^exponent / denominator (mod m)
             */
            int32_t inverse = inv_mod(denominator, m);
            int32_t t = (50 * k - 6) % m;
            t = mul_mod_23(numerator, t, m);
            t = mul_mod_21(t, powi(prime, exponent), m);
            t = mul_mod_21(t, inverse, m);

            subtotal = (subtotal + t) % m;
        }
        subtotal = mul_mod_21(subtotal, decimal_shift, m);

        /* We have a fraction over a prime power, add it to the final sum */
        fixed_point_sum(subtotal, m, &sum, &sum_low);
    }
    return sum;
}
309+
310+
/* Print "3." followed by decimal digits of pi, nine digits per pifactory()
 * call, then a newline.
 *
 * Digits are produced in groups of nine starting at offsets 0, 9, 18, ...
 * while the offset is below `end`, so with end == 100 twelve groups
 * (108 digits) are printed.
 *
 * main is declared as `int main(void)`, the signature the C standard
 * requires; int32_t is not guaranteed to be int. For the same reason the
 * pifactory() result is cast to int to match the %d conversion.
 */
int main(void)
{
    int32_t start = 0, end = 100;

    /* Print digits of pi */
    printf("3.");
    start++; /* the leading "3" accounts for the first digit */

    for (int32_t i = start - 1; i < end; i += 9)
        printf("%09d", (int) pifactory(i));
    printf("\n");

    return 0;
}

0 commit comments

Comments
 (0)