Files

97 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
The series, 1^1 + 2^2 + 3^3 + ... + 10^10 = 10405071317 .
Find the last ten digits of the series, 1^1 + 2^2 + 3^3 + ... + 1000^1000.
"""
import math
import time
from functools import lru_cache
def timer(func):
def wrapper(*args, **kwargs):
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"{func.__name__} time: {elapsed_time:.6f} seconds")
return result
return wrapper
@lru_cache(maxsize=128)
def _crt_params(n: int):
"""预计算并缓存 CRT 参数,避免重复计算"""
mod_2 = 2**n
mod_5 = 5**n
# 计算 5^n 模 2^n 的模逆元Python 3.8+ 支持负数指数求逆)
inv_5n_mod_2n = pow(mod_5, -1, mod_2)
return mod_2, mod_5, inv_5n_mod_2n
def fast_last_n_digits(base: int, exp: int, n: int) -> int:
"""
计算 base^exp 的最后 n 位数字(模 10^n
复杂度O(n * M(n/2)),其中 M(k) 是 k 位整数乘法的复杂度。
相比直接 pow(base, exp, 10**n) 的 O(log(exp) * M(n))
当 exp 极大且 n > 2 时通常快 2-4 倍。
"""
if n == 0:
return 0
if exp == 0:
return 1 % (10**n)
# 小 n 直接用内置 pow避免 CRT 开销)
if n <= 2:
return pow(base, exp, 10**n)
base = base % (10**n)
# 1. 指数缩减:利用 Carmichael 函数 λ(10^n) = 4·10^(n-1) (n≥2)
# 仅当 base 与 10 互质时适用
if math.gcd(base, 10) == 1:
lambda_n = 4 * (10 ** (n - 1))
if exp > lambda_n:
exp = exp % lambda_n
if exp == 0:
exp = lambda_n
# 2. CRT 分解:分别计算模 2^n 和模 5^n
mod_2, mod_5, inv = _crt_params(n)
# 优化模 2^n如果 base 是偶数且含有足够多的因子 2
if base & 1 == 0: # 偶数检查
# 快速计算 base 中因子 2 的个数(末尾 0 的个数)
k = (base & -base).bit_length() - 1
if k * exp >= n:
r2 = 0 # 2^n 整除 base^exp
else:
r2 = pow(base, exp, mod_2)
else:
r2 = pow(base, exp, mod_2)
# 计算模 5^n主要运算但操作数大小减半
r5 = pow(base, exp, mod_5)
# 3. CRT 合并Garner 算法):
# 求解 x ≡ r2 (mod 2^n), x ≡ r5 (mod 5^n)
# x = r5 + 5^n * ((r2 - r5) * inv(5^n, 2^n) mod 2^n)
t = ((r2 - r5) * inv) % mod_2
return r5 + mod_5 * t
@timer
def key_main(limit: int = 1000, last_n: int = 10):
res = []
for i in range(1, limit + 1):
res.append(fast_last_n_digits(i, i, last_n))
res = int("".join(list(str(sum(res)))[-last_n:]))
print(res)
if __name__ == "__main__":
key_main()