summaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py
diff options
context:
space:
mode:
authorblackhao <13851610112@163.com>2025-08-22 02:51:50 -0500
committerblackhao <13851610112@163.com>2025-08-22 02:51:50 -0500
commit4aab4087dc97906d0b9890035401175cdaab32d4 (patch)
tree4e2e9d88a711ec5b1cfa02e8ac72a55183b99123 /.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py
parentafa8f50d1d21c721dabcb31ad244610946ab65a3 (diff)
2.0
Diffstat (limited to '.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py')
-rw-r--r--.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py b/.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py
new file mode 100644
index 0000000..c1d0f4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/numpy/random/_examples/numba/extending.py
@@ -0,0 +1,86 @@
+from timeit import timeit
+
+import numba as nb
+
+import numpy as np
+from numpy.random import PCG64
+
+bit_gen = PCG64()
+next_d = bit_gen.cffi.next_double
+state_addr = bit_gen.cffi.state_address
+
+def normals(n, state):
+ out = np.empty(n)
+ for i in range((n + 1) // 2):
+ x1 = 2.0 * next_d(state) - 1.0
+ x2 = 2.0 * next_d(state) - 1.0
+ r2 = x1 * x1 + x2 * x2
+ while r2 >= 1.0 or r2 == 0.0:
+ x1 = 2.0 * next_d(state) - 1.0
+ x2 = 2.0 * next_d(state) - 1.0
+ r2 = x1 * x1 + x2 * x2
+ f = np.sqrt(-2.0 * np.log(r2) / r2)
+ out[2 * i] = f * x1
+ if 2 * i + 1 < n:
+ out[2 * i + 1] = f * x2
+ return out
+
+
+# Compile using Numba
+normalsj = nb.jit(normals, nopython=True)
+# Must use state address not state with numba
+n = 10000
+
+def numbacall():
+ return normalsj(n, state_addr)
+
+
+rg = np.random.Generator(PCG64())
+
+def numpycall():
+ return rg.normal(size=n)
+
+
+# Check that the functions work
+r1 = numbacall()
+r2 = numpycall()
+assert r1.shape == (n,)
+assert r1.shape == r2.shape
+
+t1 = timeit(numbacall, number=1000)
+print(f'{t1:.2f} secs for {n} PCG64 (Numba/PCG64) gaussian randoms')
+t2 = timeit(numpycall, number=1000)
+print(f'{t2:.2f} secs for {n} PCG64 (NumPy/PCG64) gaussian randoms')
+
+# example 2
+
+next_u32 = bit_gen.ctypes.next_uint32
+ctypes_state = bit_gen.ctypes.state
+
+@nb.jit(nopython=True)
+def bounded_uint(lb, ub, state):
+ mask = delta = ub - lb
+ mask |= mask >> 1
+ mask |= mask >> 2
+ mask |= mask >> 4
+ mask |= mask >> 8
+ mask |= mask >> 16
+
+ val = next_u32(state) & mask
+ while val > delta:
+ val = next_u32(state) & mask
+
+ return lb + val
+
+
+print(bounded_uint(323, 2394691, ctypes_state.value))
+
+
+@nb.jit(nopython=True)
+def bounded_uints(lb, ub, n, state):
+ out = np.empty(n, dtype=np.uint32)
+ for i in range(n):
+ out[i] = bounded_uint(lb, ub, state)
+
+
+bounded_uints(323, 2394691, 10000000, ctypes_state.value)