feat(tools): 新增题目时自动生成带计时装饰器的模板文件

This commit is contained in:
2026-02-19 14:32:07 +08:00
parent f42b8e6461
commit 4dae9d2e78
2 changed files with 115 additions and 1 deletions

View File

@@ -0,0 +1,96 @@
"""
The series, 1^1 + 2^2 + 3^3 + ... + 10^10 = 10405071317 .
Find the last ten digits of the series, 1^1 + 2^2 + 3^3 + ... + 1000^1000.
"""
import math
import time
from functools import lru_cache
def timer(func):
def wrapper(*args, **kwargs):
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"{func.__name__} time: {elapsed_time:.6f} seconds")
return result
return wrapper
@lru_cache(maxsize=128)
def _crt_params(n: int):
"""预计算并缓存 CRT 参数,避免重复计算"""
mod_2 = 2**n
mod_5 = 5**n
# 计算 5^n 模 2^n 的模逆元Python 3.8+ 支持负数指数求逆)
inv_5n_mod_2n = pow(mod_5, -1, mod_2)
return mod_2, mod_5, inv_5n_mod_2n
def fast_last_n_digits(base: int, exp: int, n: int) -> int:
"""
计算 base^exp 的最后 n 位数字(模 10^n
复杂度O(n * M(n/2)),其中 M(k) 是 k 位整数乘法的复杂度。
相比直接 pow(base, exp, 10**n) 的 O(log(exp) * M(n))
当 exp 极大且 n > 2 时通常快 2-4 倍。
"""
if n == 0:
return 0
if exp == 0:
return 1 % (10**n)
# 小 n 直接用内置 pow避免 CRT 开销)
if n <= 2:
return pow(base, exp, 10**n)
base = base % (10**n)
# 1. 指数缩减:利用 Carmichael 函数 λ(10^n) = 4·10^(n-1) (n≥2)
# 仅当 base 与 10 互质时适用
if math.gcd(base, 10) == 1:
lambda_n = 4 * (10 ** (n - 1))
if exp > lambda_n:
exp = exp % lambda_n
if exp == 0:
exp = lambda_n
# 2. CRT 分解:分别计算模 2^n 和模 5^n
mod_2, mod_5, inv = _crt_params(n)
# 优化模 2^n如果 base 是偶数且含有足够多的因子 2
if base & 1 == 0: # 偶数检查
# 快速计算 base 中因子 2 的个数(末尾 0 的个数)
k = (base & -base).bit_length() - 1
if k * exp >= n:
r2 = 0 # 2^n 整除 base^exp
else:
r2 = pow(base, exp, mod_2)
else:
r2 = pow(base, exp, mod_2)
# 计算模 5^n主要运算但操作数大小减半
r5 = pow(base, exp, mod_5)
# 3. CRT 合并Garner 算法):
# 求解 x ≡ r2 (mod 2^n), x ≡ r5 (mod 5^n)
# x = r5 + 5^n * ((r2 - r5) * inv(5^n, 2^n) mod 2^n)
t = ((r2 - r5) * inv) % mod_2
return r5 + mod_5 * t
@timer
def key_main(limit: int = 1000, last_n: int = 10):
res = []
for i in range(1, limit + 1):
res.append(fast_last_n_digits(i, i, last_n))
res = int("".join(list(str(sum(res)))[-last_n:]))
print(res)
if __name__ == "__main__":
key_main()