import math
import time
from functools import lru_cache, wraps
from typing import Callable


def timer(func: Callable) -> Callable:
    """Decorate *func* so every call prints its elapsed time.

    Elapsed time is measured with ``time.perf_counter``; the wrapped
    function's return value is passed through unchanged.
    """

    @wraps(func)  # keep the decorated function's __name__/__doc__
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        elapsed = end_time - start_time
        print(f"Function {func.__name__} execution time: {elapsed:.4f} seconds")
        return result

    return wrapper


def maxpower(a: int, n: int) -> int:
    """Return the largest integer c such that a**c <= n.

    Uses exact integer arithmetic, so the result does not depend on
    floating-point rounding of ``math.log`` for large *n*.

    Args:
        a: Base; must be >= 2.
        n: Upper bound; must be >= 1.

    Raises:
        ValueError: if ``a < 2`` (no finite answer) or ``n < 1``
            (no non-negative exponent satisfies the bound).
    """
    if a < 2:
        raise ValueError(f"base must be >= 2, got {a}")
    if n < 1:
        raise ValueError(f"bound must be >= 1, got {n}")
    exponent = 0
    power = a  # invariant: power == a ** (exponent + 1)
    while power <= n:
        power *= a
        exponent += 1
    return exponent


@lru_cache(maxsize=None)
def lcm(a: int, b: int) -> int:
    """Return the least common multiple of two positive integers (memoized)."""
    gcd_val = math.gcd(a, b)
    # Divide before multiplying to keep the intermediate value small.
    return a // gcd_val * b


def recurse(
    lc: int, index: int, sign: int, left: int, right: int, thelist: list[int]
) -> int:
    """Recursive inclusion-exclusion term.

    Returns ``sign`` times the number of multiples of *lc* in
    ``[left, right]``, plus, for every ``i > index``, the term for
    ``lcm(lc, thelist[i])`` with the sign flipped.  Summing this over each
    starting element with sign +1 (as ``dd`` does) yields the count of
    numbers divisible by at least one of the combined moduli.
    """
    if lc > right:
        # The lcm only grows along a branch, so all deeper terms are 0 too.
        return 0
    res = sign * (right // lc - (left - 1) // lc)
    # Extend the current subset with each later element.
    for i in range(index + 1, len(thelist)):
        res += recurse(lcm(lc, thelist[i]), i, -sign, left, right, thelist)
    return res


def dd(left: int, right: int, a: int, b: int, check: list[bool]) -> int:
    """Count multiples of *b* in ``[left, right]`` that are divisible by no
    ``k`` in ``[a, b)`` with ``check[k]`` True.

    Two-level inclusion-exclusion: all multiples of *b*, minus those that
    are also divisible by at least one of the selected ``k``.
    """
    res = right // b - (left - 1) // b
    thelist = [k for k in range(a, b) if check[k]]
    for i, k in enumerate(thelist):
        res -= recurse(lcm(b, k), i, 1, left, right, thelist)
    return res


def compute_counts(n: int) -> tuple[list[int], int, int]:
    """Build the prefix-summed exponent-count table.

    Returns ``(counts, isqrt(n), maxc)`` with ``maxc = maxpower(2, n)``.
    For ``1 <= c <= maxc``, ``counts[c]`` is the number of distinct products
    ``k * b`` with ``1 <= k <= c`` and ``2 <= b <= n`` -- i.e. the number of
    distinct exponents of ``r**(k*b)`` for a root ``r`` whose first ``c``
    powers do not exceed *n*.  ``counts[0]`` is 0.

    Raises:
        ValueError: if ``n < 1`` (propagated from ``maxpower``).
    """
    sqn = math.isqrt(n)
    maxc = maxpower(2, n)
    # counts[1] must exist even when maxc == 0 (n == 1).
    counts = [0] * (max(maxc, 1) + 1)
    # Chain length 1: exponents b = 2..n are all distinct.
    counts[1] = n - 1

    for c in range(2, maxc + 1):
        check = [True] * (maxc + 1)
        # Window of products v with (c-1)*n < v <= c*n.  Inside it, v can be
        # written as k*b with b <= n exactly when k >= c.
        umin = (c - 1) * n + 1
        umax = c * n
        # Clear check[k] for every k that is a proper multiple of some
        # i >= c: any multiple of k is already a multiple of i, so k is a
        # redundant blocker and contributes no new products itself.
        for i in range(c, maxc // 2 + 1):
            check[i * 2 : maxc + 1 : i] = [False] * ((maxc - i * 2) // i + 1)
        # New products f*b in this window are those not divisible by any
        # remaining k in [c, f).
        for f in range(c, maxc + 1):
            if check[f]:
                counts[f] += dd(umin, umax, c, f, check)

    # Turn per-f increments into cumulative counts over chain length.
    for c in range(2, maxc + 1):
        counts[c] += counts[c - 1]
    return counts, sqn, maxc


def compute_final_answer(n: int) -> int:
    """Return the number of distinct terms ``a**b`` for ``2 <= a, b <= n``.

    Returns 0 when ``n < 2`` (there are no bases/exponents in range).
    """
    if n < 2:
        return 0
    counts, sqn, _ = compute_counts(n)
    ans = 0
    coll = 0  # perfect powers in (sqn, n]; already counted via their root
    used = [False] * (sqn + 1)  # marks perfect powers <= sqn

    for i in range(2, sqn + 1):
        if used[i]:
            continue
        # i is not a perfect power: it is the root of i, i**2, ..., i**c.
        c = maxpower(i, n)
        ans += counts[c]
        u = i
        for j in range(2, c + 1):
            u *= i
            if u <= sqn:
                used[u] = True
            else:
                # i**j .. i**c all lie in (sqn, n].
                coll += c - j + 1
                break

    # Every base in (sqn, n] yields n - 1 distinct terms, except the perfect
    # powers that were already accounted for through their root above.
    ans += (n - sqn) * (n - 1)
    ans -= coll * (n - 1)
    return ans


@timer
def main(n: int = 10**6) -> None:
    """Compute and print the answer for *n*; the call is timed by ``timer``."""
    answer = compute_final_answer(n)
    print(f"n = {n}, Answer = {answer}")


if __name__ == "__main__":
    main(10**7)