23 Commits

Author SHA1 Message Date
459670f363 fix(llm): 修复 Kimi API temperature 参数配置 2026-06-01 17:48:42 +08:00
7636d0b5a6 feat(llm): 统一思考模式配置,支持显式禁用状态 2026-06-01 17:39:36 +08:00
928ebb61b4 refactor(llm): renumber system prompt rules 2026-05-27 15:37:51 +08:00
7e85cdd8b0 chore(release): 升级版本号至 0.3.0 2026-05-27 15:16:26 +08:00
90074e6e32 style: 格式化代码并优化导入顺序 2026-05-27 15:15:15 +08:00
b8182e7538 修复kimi返回信息的读取错误 2026-05-27 14:50:47 +08:00
4331b9306e LLM支持优化 2026-05-26 17:43:42 +08:00
a08bc809bb 修复bug 2026-05-26 16:30:28 +08:00
1063369d96 feat(deepseek): 添加 DeepSeek reasoning 模式支持 2026-05-26 16:27:49 +08:00
3a57d25a76 docs: 添加 QuiCommit 项目路线图文档 2026-05-14 17:04:13 +08:00
8152edba39 chore: 删除构建输出日志文件 2026-05-13 13:54:50 +08:00
679db5b1db chore: 清理大量未使用的变量、方法及结构体警告 2026-05-13 13:54:20 +08:00
b1ad68c7b5 build: 升级版本号至 0.2.0 2026-05-13 12:08:03 +08:00
280d6ec5c9 feat(generator): 按文件重要性对暂存差异排序 2026-05-13 12:07:01 +08:00
68427c4a11 chore: 发布 v0.1.11 并更新文档 2026-03-23 18:07:11 +08:00
8dd9e85b77 feat(config): 在加密导出/导入中包含个人访问令牌 2026-03-23 17:59:23 +08:00
0c7d2ad518 refactor: 移除未使用的代码和注释掉的辅助函数 2026-03-20 18:05:33 +08:00
0289dd4684 chore: 升级版本至 0.1.10 并更新密钥环与加密相关描述 2026-03-19 16:34:45 +08:00
e2d43315e3 feat: 新增系统密钥环安全存储API密钥与自动生成变更日志功能 2026-03-12 18:05:38 +08:00
0e1c2c6350 chore: 移除测试 keyring 功能的临时文件 2026-03-12 17:45:05 +08:00
da85fc94b1 feat(keyring): 集成系统密钥环安全存储 API key 2026-03-12 17:42:41 +08:00
c66d782eab chore: 更新版本号至 0.1.9 并补充 changelog 2026-03-06 16:31:46 +08:00
358b44ab81 fix(generator): 修复diff截断时的字符边界问题 2026-03-06 16:28:26 +08:00
41 changed files with 7352 additions and 2851 deletions

2
.gitignore vendored
View File

@@ -21,3 +21,5 @@ test_output/
# Config (for development)
config.toml
.claude/
CLAUDE.md

View File

@@ -5,6 +5,55 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
暂无。
## [0.3.1] - 2026-06-01
### ✨ 新功能
- 按文件重要性对暂存差异排序,优先处理核心变更
- DeepSeek 新增 reasoning 推理模式支持
- LLM 统一思考模式配置,支持显式启用/禁用思考状态
- 新增 `thinking.rs` 思考状态管理模块
### 🐞 错误修复
- 修复 Kimi 返回信息的读取错误
- 修复 DeepSeek 和 Kimi 流式响应的解析问题
### 📚 文档
- 新增 ROADMAP.md 项目路线图文档
### 🔧 其他变更
- LLM 模块大规模重构所有提供商Anthropic、DeepSeek、Kimi、Ollama、OpenAI、OpenRouter适配流式响应处理
- 代码格式化并优化导入顺序
- 清理大量未使用的变量、方法及结构体警告
- 清理构建输出日志文件
- 重新编号 LLM 系统提示规则
- i18n 多语言消息格式修复
- 各命令模块commit、tag、changelog、config、profile、init持续优化
## [0.1.11] - 2026-03-23
### ✨ 新功能
- 新增配置导出导入功能,支持加密保护
- Profile 支持 Token 管理PAT 等)
- 自动生成和维护 Keep a Changelog 格式的变更日志
- 交互式命令行界面,支持预览和确认
### 🔐 安全特性
- 敏感数据加密存储API 密钥等)
- 使用系统密钥环安全保存凭证
### 🔧 其他变更
- 优化 diff 截断逻辑,使用字符边界确保多字节字符安全
- 改进配置管理器,支持修改追踪
## [0.1.9] - 2026-03-06
### 🐞 错误修复
- 修复diff截断时的字符边界问题
## [0.1.7] - 2026-02-14
### 🐞 错误修复

View File

@@ -1,6 +1,6 @@
[package]
name = "quicommit"
version = "0.1.8"
version = "0.3.1"
edition = "2024"
authors = ["Sidney Zhang <zly@lyzhang.me>"]
description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)"
@@ -33,7 +33,7 @@ git2 = "0.20.3"
which = "6.0"
# HTTP client for LLM APIs
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
tokio = { version = "1.35", features = ["full"] }
# Error handling
@@ -57,6 +57,7 @@ sha2 = "0.10"
hex = "0.4"
textwrap = "0.16"
async-trait = "0.1"
futures-util = "0.3"
serde_json = "1.0"
atty = "0.2"
@@ -66,6 +67,9 @@ argon2 = "0.5"
rand = "0.8"
base64 = "0.22"
# System keyring for secure API key storage
keyring = { version = "3.6.3", features = ["apple-native", "windows-native", "sync-secret-service"] }
# Interactive editor
edit = "0.1"
@@ -80,11 +84,13 @@ mockall = "0.12"
wiremock = "0.6"
[profile.release]
opt-level = 3
lto = true
opt-level = "s"
lto = "thin"
codegen-units = 1
panic = "abort"
strip = true
debug = false
[profile.dev]
opt-level = 0
opt-level = 1
debug = true

128
RAODMAP.md Normal file
View File

@@ -0,0 +1,128 @@
# QuiCommit Roadmap
## 已完成 ✅
- [x] 基础 Git 操作commit、tag、changelog
- [x] AI 驱动提交信息生成Conventional Commits / commitlint 格式)
- [x] 多 LLM 提供商支持Ollama、OpenAI、Anthropic、Kimi、DeepSeek、OpenRouter
- [x] 多 Git Profile 管理SSH 密钥 + GPG 签名)
- [x] 语义化版本自动升级与 AI 发布说明
- [x] Keep a Changelog 格式自动生成
- [x] 系统密钥环安全存储 API Key
- [x] 敏感数据加密存储AES-GCM + Argon2
- [x] 交互式 CLI 预览与确认
- [x] 7 种语言国际化支持
- [x] 配置导出/导入(支持加密保护)
- [x] Profile Token 管理PAT 等)
---
## 进行中 🚧
暂无。
---
## 计划中 📋
### 1. Git 凭证管理器
将 Git 凭证管理集成到 QuiCommit 中,统一管理 HTTPS 仓库的身份认证。
- [ ] **Git Credential Helper 集成**
- 实现 `git credential-store` / `git-credential-libsecret` 等标准的 credential helper 协议
- 支持 `quicommit credential get|store|erase` 子命令
- 与系统密钥环无缝对接,复用已有的 `KeyringManager`
- [ ] **凭证管理 CLI**
- `quicommit credential list` — 列出所有已存储的凭证
- `quicommit credential add` — 手动添加凭证(用户名 + 密码/Token
- `quicommit credential remove` — 删除指定凭证
- `quicommit credential status` — 查看凭证管理状态
- [ ] **跨平台支持**
- Windows集成 Windows Credential Manager
- macOS集成 Keychain
- Linux通过 Secret Service / D-Bus 对接 GNOME Keyring / KWallet
- [ ] **安全增强**
- 支持 PATPersonal Access Token按 scope / 有效期管理
- 支持凭证过期检查和自动提醒
---
### 2. 新增模型支持
扩展 LLM 提供商和模型覆盖范围,满足更多场景和偏好。
- [x] **新增 DeepSeek 最新模型**
- 支持 `deepseek-chat`DeepSeek-V3
- 支持 `deepseek-reasoner`DeepSeek-R1
- 支持 `deepseek-v4`
- [ ] **新增国内模型提供商**
- 通义千问 (Qwen) — 阿里云 DashScope API
- 文心一言 (ERNIE) — 百度千帆 API
- 智谱 GLM — ChatGLM API
- 百川 (Baichuan) — Baichuan API
- [ ] **新增国际模型提供商**
- Google Gemini API
- Mistral AI API
- Cohere API
- Groq (LPU 推理加速)
- [ ] **本地模型扩展**
- 支持 llama.cpp 服务端(兼容 OpenAI API 格式)
- 支持 vLLM 部署的模型
- 本地模型推荐列表与一键配置向导
- [ ] **模型能力适配**
- 不同模型的 token 限制自适应
- 模型特定的 prompt 模板优化
- 支持 function calling / tool use用于复杂生成场景
---
### 3. 生成体验优化
提升 AI 生成提交信息、标签说明和变更日志时的用户体验。
- [x] **流式输出与实时反馈**
- 支持 SSEServer-Sent Events流式生成
- 终端打字机效果实时显示生成内容
- 流式生成过程中支持 `Ctrl+C` 中断
- [ ] **生成质量提升**
- 基于 commitlint 规则的后校验与自动修正
- 支持 Few-shot 示例引导(用户可自定义示例库)
- 生成结果的置信度评分与多候选方案
- [ ] **Diff 上下文增强**
- 智能 diff 摘要(大改动时自动压缩关键信息)
- 支持 `.gitattributes` 排除/包含规则
- 按文件类型分组生成更精准的提交描述
- [ ] **交互式编辑增强**
- 生成后支持内联编辑(类似 `git rebase -i` 体验)
- 支持重新生成指定部分(如 scope、description
- 历史提交信息学习与风格适配
- [ ] **批量操作支持**
- 批量生成多个 commit分组暂存区变更
- `--dry-run` 预览模式(只生成本地查看,不写 Git
- [ ] **性能优化**
- API 请求并发优化(多个模型同时生成候选)
- 本地缓存常用 prompt 模板
- 减少不必要的 diff 计算
---
## 长远规划 🌟
- [ ] **VS Code 扩展** — 在 IDE 内直接使用 QuiCommit
- [ ] **GitHub Action / GitLab CI 集成** — 自动化 PR 标题和描述生成
- [ ] **团队协作** — 共享 commit 风格配置、prompt 模板库
- [ ] **Web Dashboard** — 可视化管理多仓库的 Git 活动与 AI 生成统计
- [ ] **插件系统** — 允许社区贡献自定义 LLM 提供商和生成策略

View File

@@ -6,8 +6,12 @@ A powerful AI-powered Git assistant for generating conventional commits, tags, a
[Still in early development, some features may not be complete. Feedback and contributions are welcome.]
> ⚠️ **Important Notice**: QuiCommit now uses system keyring to store API keys securely. This change may cause breaking changes to your existing configuration. If you encounter issues after updating, please run `quicommit config reset --force` to reset your configuration, then reconfigure your settings.
![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white)
![License](https://img.shields.io/badge/license-MIT-blue.svg)
![Crates.io Version](https://img.shields.io/crates/v/quicommit)
## Features
@@ -16,8 +20,10 @@ A powerful AI-powered Git assistant for generating conventional commits, tags, a
- **Profile Management**: Manage multiple Git identities with SSH keys and GPG signing support
- **Smart Tagging**: Semantic version bumping with AI-generated release notes
- **Changelog Generation**: Automatic changelog generation in Keep a Changelog format
- **Security**: Encrypt sensitive data
- **Security**: Use system keyring to store API keys securely
- **Interactive UI**: Beautiful CLI with previews and confirmations
- **Multi-language Support**: Output in 7 languages (English, Chinese, Japanese, Korean, Spanish, French, German)
- **Config Export/Import**: Backup and restore configuration with optional encryption
## Installation
@@ -159,30 +165,30 @@ quicommit profile token
```bash
# Configure Ollama (local)
quicommit config set-llm ollama
quicommit config set-ollama --url http://localhost:11434 --model llama2
quicommit config set-llm ollama --url http://localhost:11434 --model llama2
# Configure OpenAI
quicommit config set-llm openai
quicommit config set-openai-key YOUR_API_KEY
quicommit config set-api-key YOUR_API_KEY
# Configure Anthropic Claude
quicommit config set-llm anthropic
quicommit config set-anthropic-key YOUR_API_KEY
quicommit config set-api-key YOUR_API_KEY
# Configure Kimi (Moonshot AI)
quicommit config set-llm kimi
quicommit config set-kimi-key YOUR_API_KEY
quicommit config set-kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
# Configure DeepSeek
quicommit config set-llm deepseek
quicommit config set-deepseek-key YOUR_API_KEY
quicommit config set-deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
# Configure OpenRouter
quicommit config set-llm openrouter
quicommit config set-openrouter-key YOUR_API_KEY
quicommit config set-openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
# Set commit format
quicommit config set-commit-format conventional
@@ -193,7 +199,7 @@ quicommit config set-version-prefix v
# Set changelog path
quicommit config set-changelog-path CHANGELOG.md
# Set output language
# Set output language (en, zh, ja, ko, es, fr, de)
quicommit config set-language en
# Set keep commit types in English
@@ -205,8 +211,22 @@ quicommit config set-keep-changelog-types-english true
# Test LLM connection
quicommit config test-llm
# Check keyring availability
quicommit config check-keyring
# Show config file path
quicommit config path
# Export configuration (with optional encryption)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# Import configuration
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# Reset configuration to defaults
quicommit config reset
quicommit config reset --force
```
## Command Reference
@@ -396,17 +416,31 @@ quicommit config set llm.provider ollama
# Get configuration value
quicommit config get llm.provider
# Set API key (stored in system keyring)
quicommit config set-api-key YOUR_API_KEY
# Delete API key from keyring
quicommit config delete-api-key
# Test LLM connection
quicommit config test-llm
# List available models
quicommit config list-models
# Export configuration
# Check keyring availability
quicommit config check-keyring
# Show config file path
quicommit config path
# Export configuration (with optional encryption)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# Import configuration
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# Reset configuration
quicommit config reset --force

View File

@@ -1,12 +1,13 @@
use std::env;
fn main() {
// Only generate completions when explicitly requested
if env::var("GENERATE_COMPLETIONS").is_ok() {
println!("cargo:warning=To generate shell completions, run: cargo run --bin quicommit -- completions");
println!(
"cargo:warning=To generate shell completions, run: cargo run --bin quicommit -- completions"
);
}
// Rerun if build.rs changes
println!("cargo:rerun-if-changed=build.rs");
}

View File

@@ -4,6 +4,13 @@
# - macOS: ~/Library/Application Support/quicommit/config.toml
# - Windows: %APPDATA%\quicommit\config.toml
# ⚠️ IMPORTANT: Keyring Feature Update
# QuiCommit now uses system keyring to store API keys securely.
# This change may cause breaking changes to your existing configuration.
# If you encounter issues after updating, please reset your configuration:
# quicommit config reset --force
# Then reconfigure your settings using the CLI commands.
# Configuration version (for migration)
version = "1"

View File

@@ -6,8 +6,11 @@
【目前还处在早期开发阶段,依然有一些功能未完善,欢迎反馈和贡献。】
> ⚠️ **重要提示**QuiCommit 现在使用系统密钥环keyring来安全存储 API 密钥。此更改可能会对现有配置造成破坏性变更。如果在更新后遇到问题,请运行 `quicommit config reset --force` 重置配置,然后重新配置您的设置。
![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white)
![License](https://img.shields.io/badge/license-MIT-blue.svg)
![Crates.io Version](https://img.shields.io/crates/v/quicommit)
## 主要功能
@@ -16,8 +19,10 @@
- **多配置管理**为不同场景管理多个Git身份支持SSH密钥和GPG签名配置
- **智能标签管理**基于语义版本自动检测升级AI生成标签信息
- **变更日志生成**自动生成Keep a Changelog格式的变更日志
- **安全保护**加密存储敏感数据
- **安全保护**使用系统密钥环进行安全存储
- **交互式界面**美观的CLI界面支持预览和确认
- **多语言支持**支持7种语言输出中文、英语、日语、韩语、西班牙语、法语、德语
- **配置导出导入**:备份和恢复配置,支持加密保护
## 安装
@@ -159,30 +164,30 @@ quicommit profile token
```bash
# 配置Ollama本地
quicommit config set-llm ollama
quicommit config set-ollama --url http://localhost:11434 --model llama2
quicommit config set-llm ollama --url http://localhost:11434 --model llama2
# 配置OpenAI
quicommit config set-llm openai
quicommit config set-openai-key YOUR_API_KEY
quicommit config set-api-key YOUR_API_KEY
# 配置Anthropic Claude
quicommit config set-llm anthropic
quicommit config set-anthropic-key YOUR_API_KEY
quicommit config set-api-key YOUR_API_KEY
# 配置Kimi
quicommit config set-llm kimi
quicommit config set-kimi-key YOUR_API_KEY
quicommit config set-kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
# 配置DeepSeek
quicommit config set-llm deepseek
quicommit config set-deepseek-key YOUR_API_KEY
quicommit config set-deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
# 配置OpenRouter
quicommit config set-llm openrouter
quicommit config set-openrouter-key YOUR_API_KEY
quicommit config set-openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
quicommit config set-api-key YOUR_API_KEY
quicommit config set-llm openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
# 设置提交格式
quicommit config set-commit-format conventional
@@ -193,7 +198,7 @@ quicommit config set-version-prefix v
# 设置变更日志路径
quicommit config set-changelog-path CHANGELOG.md
# 设置输出语言
# 设置输出语言zh, en, ja, ko, es, fr, de
quicommit config set-language zh
# 设置保持提交类型为英文
@@ -205,8 +210,22 @@ quicommit config set-keep-changelog-types-english true
# 测试LLM连接
quicommit config test-llm
# 检查密钥环可用性
quicommit config check-keyring
# 显示配置文件路径
quicommit config path
# 导出配置(支持加密)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# 导入配置
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# 重置配置为默认值
quicommit config reset
quicommit config reset --force
```
## 命令参考
@@ -396,17 +415,31 @@ quicommit config set llm.provider ollama
# 获取配置值
quicommit config get llm.provider
# 设置API密钥存储在系统密钥环中
quicommit config set-api-key YOUR_API_KEY
# 从密钥环删除API密钥
quicommit config delete-api-key
# 测试LLM连接
quicommit config test-llm
# 列出可用模型
quicommit config list-models
# 导出配置
# 检查密钥环可用性
quicommit config check-keyring
# 显示配置文件路径
quicommit config path
# 导出配置(支持加密)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# 导入配置
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# 重置配置
quicommit config reset --force

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use chrono::Utc;
use clap::Parser;
use colored::Colorize;
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager};
use crate::generator::ContentGenerator;
use crate::git::find_repo;
use crate::git::{changelog::*, CommitInfo};
use crate::git::{CommitInfo, changelog::*};
use crate::i18n::{Messages, translate_changelog_category};
/// Generate changelog
@@ -55,6 +55,10 @@ pub struct ChangelogCommand {
#[arg(long)]
dry_run: bool,
/// Enable thinking mode for this changelog (override config)
#[arg(long)]
think: bool,
/// Skip interactive prompts
#[arg(short = 'y', long)]
yes: bool,
@@ -74,18 +78,20 @@ impl ChangelogCommand {
// Initialize changelog if requested
if self.init {
let path = self.output.as_ref()
.map(|p| p.clone())
let path = self
.output
.clone()
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
init_changelog(&path)?;
println!("{}", messages.initialized_changelog(&format!("{:?}", path)));
return Ok(());
}
// Determine output path
let output_path = self.output.as_ref()
.map(|p| p.clone())
let output_path = self
.output
.clone()
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
// Determine format
@@ -94,7 +100,10 @@ impl ChangelogCommand {
Some("keep") | Some("keep-a-changelog") => ChangelogFormat::KeepAChangelog,
Some("custom") => ChangelogFormat::Custom,
None => ChangelogFormat::KeepAChangelog,
Some(f) => bail!("Unknown format: {}. Use: keep-a-changelog, github-releases", f),
Some(f) => bail!(
"Unknown format: {}. Use: keep-a-changelog, github-releases",
f
),
};
// Get version
@@ -112,11 +121,11 @@ impl ChangelogCommand {
// Get commits
println!("{}", messages.fetching_commits());
let commits = generate_from_history(&repo, self.from.as_deref(), Some(&self.to))?;
if commits.is_empty() {
bail!("{}", messages.no_commits_found());
}
println!("{}", messages.found_commits(commits.len()));
// Generate changelog
@@ -148,7 +157,7 @@ impl ChangelogCommand {
println!("{}", "".repeat(60));
let confirm = Confirm::new()
.with_prompt(&messages.write_to_file(&format!("{:?}", output_path)))
.with_prompt(messages.write_to_file(&format!("{:?}", output_path)))
.default(true)
.interact()?;
@@ -168,7 +177,7 @@ impl ChangelogCommand {
} else if existing.starts_with("# Changelog") {
let lines: Vec<&str> = existing.lines().collect();
let mut header_end = 0;
for (i, line) in lines.iter().enumerate() {
if i == 0 && line.starts_with('#') {
header_end = i + 1;
@@ -178,10 +187,10 @@ impl ChangelogCommand {
break;
}
}
let header = lines[..header_end].join("\n");
let rest = lines[header_end..].join("\n");
format!("{}\n{}\n{}", header, changelog, rest)
} else {
format!("{}{}", CHANGELOG_HEADER, changelog)
@@ -204,13 +213,14 @@ impl ChangelogCommand {
messages: &Messages,
) -> Result<String> {
let manager = ConfigManager::new()?;
let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English);
println!("{}", messages.ai_generating_changelog());
let generator = ContentGenerator::new(&config.llm).await?;
generator.generate_changelog_entry(version, commits, language).await
let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
generator
.generate_changelog_entry(version, commits, language)
.await
}
fn generate_with_template(
@@ -221,14 +231,14 @@ impl ChangelogCommand {
language: Language,
) -> Result<String> {
let manager = ConfigManager::new()?;
let generator = ChangelogGenerator::new()
.format(format)
.include_hashes(self.include_hashes)
.include_authors(self.include_authors);
let changelog = generator.generate(version, Utc::now(), commits)?;
// Translate changelog categories if configured
if !manager.keep_changelog_types_english() {
Ok(self.translate_changelog_categories(&changelog, language))
@@ -236,14 +246,15 @@ impl ChangelogCommand {
Ok(changelog)
}
}
fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String {
let translated = changelog
changelog
.lines()
.map(|line| {
if line.starts_with("## ") || line.starts_with("### ") {
let category = line.trim_start_matches("## ").trim_start_matches("### ");
let translated_category = translate_changelog_category(category, language, false);
let translated_category =
translate_changelog_category(category, language, false);
if line.starts_with("## ") {
format!("## {}", translated_category)
} else {
@@ -254,7 +265,6 @@ impl ChangelogCommand {
}
})
.collect::<Vec<_>>()
.join("\n");
translated
.join("\n")
}
}

View File

@@ -1,14 +1,14 @@
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use clap::Parser;
use colored::Colorize;
use dialoguer::{Confirm, Input, Select};
use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager};
use crate::config::CommitFormat;
use crate::config::{Language, manager::ConfigManager};
use crate::generator::ContentGenerator;
use crate::git::{find_repo, GitRepo};
use crate::git::commit::{CommitBuilder, create_date_commit_message};
use crate::git::{GitRepo, find_repo};
use crate::i18n::Messages;
use crate::utils::validators::get_commit_types;
@@ -71,6 +71,10 @@ pub struct CommitCommand {
#[arg(long)]
no_verify: bool,
/// Enable thinking mode for this commit (override config)
#[arg(short = 't', long)]
think: bool,
/// Skip interactive prompts
#[arg(short = 'y', long)]
yes: bool,
@@ -88,7 +92,7 @@ impl CommitCommand {
pub async fn execute(&self, config_path: Option<PathBuf>) -> Result<()> {
// Find git repository
let repo = find_repo(std::env::current_dir()?.as_path())?;
// Load configuration
let manager = if let Some(ref path) = config_path {
ConfigManager::with_path(path)?
@@ -98,7 +102,7 @@ impl CommitCommand {
let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English);
let messages = Messages::new(language);
// Check for changes
let status = repo.status_summary()?;
if status.clean && !self.amend {
@@ -119,7 +123,7 @@ impl CommitCommand {
println!("{}", messages.auto_stage_changes().yellow());
repo.stage_all()?;
println!("{}", messages.staged_all().green());
// Re-check status after staging to ensure changes are detected
let new_status = repo.status_summary()?;
if new_status.staged == 0 {
@@ -179,14 +183,22 @@ impl CommitCommand {
let result = if self.amend {
if self.dry_run {
println!("\n{} {}", messages.dry_run(), "- commit not amended.".yellow());
println!(
"\n{} {}",
messages.dry_run(),
"- commit not amended.".yellow()
);
return Ok(());
}
self.amend_commit(&repo, &commit_message)?;
None
} else {
if self.dry_run {
println!("\n{} {}", messages.dry_run(), "- commit not created.".yellow());
println!(
"\n{} {}",
messages.dry_run(),
"- commit not created.".yellow()
);
return Ok(());
}
CommitBuilder::new()
@@ -196,9 +208,13 @@ impl CommitCommand {
};
if let Some(commit_oid) = result {
println!("{} {}", messages.commit_created().green().bold(), commit_oid.to_string()[..8].to_string().cyan());
println!(
"{} {}",
messages.commit_created().green().bold(),
commit_oid.to_string()[..8].to_string().cyan()
);
} else {
println!("{} {}", messages.commit_amended().green().bold(), "successfully");
println!("{} successfully", messages.commit_amended().green().bold());
}
// Push after commit if requested or ask user
@@ -228,8 +244,9 @@ impl CommitCommand {
}
fn create_manual_commit(&self, format: CommitFormat) -> Result<String> {
let description = self.message.clone()
.ok_or_else(|| anyhow::anyhow!("Description required for manual commit. Use -m <message>"))?;
let description = self.message.clone().ok_or_else(|| {
anyhow::anyhow!("Description required for manual commit. Use -m <message>")
})?;
// Try to extract commit type from message if not provided
let commit_type = if let Some(ref ct) = self.commit_type {
@@ -255,31 +272,40 @@ impl CommitCommand {
builder.build_message()
}
async fn generate_commit(&self, repo: &GitRepo, format: CommitFormat, messages: &Messages) -> Result<String> {
async fn generate_commit(
&self,
repo: &GitRepo,
format: CommitFormat,
messages: &Messages,
) -> Result<String> {
let manager = ConfigManager::new()?;
let config = manager.config();
// Check if LLM is configured
let generator = ContentGenerator::new(&config.llm).await
let generator = ContentGenerator::new_with_think(&manager, self.think)
.await
.context("Failed to initialize LLM. Use --manual for manual commit.")?;
println!("{}", messages.ai_analyzing());
let language_str = &config.language.output_language;
let language = Language::from_str(language_str).unwrap_or(Language::English);
let language = manager.get_language().unwrap_or(Language::English);
let generated = if self.yes {
// Non-interactive mode: generate directly
generator.generate_commit_from_repo(repo, format, language).await?
generator
.generate_commit_from_repo(repo, format, language)
.await?
} else {
// Interactive mode: allow user to review and regenerate
generator.generate_commit_interactive(repo, format, language).await?
generator
.generate_commit_interactive(repo, format, language)
.await?
};
Ok(generated.to_conventional())
}
async fn create_interactive_commit(&self, format: CommitFormat, messages: &Messages) -> Result<String> {
async fn create_interactive_commit(
&self,
format: CommitFormat,
messages: &Messages,
) -> Result<String> {
let types = get_commit_types(format == CommitFormat::Commitlint);
// Select type
@@ -357,20 +383,21 @@ impl CommitCommand {
if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() {
if stdout.is_empty() {
"GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\
3. You can sign commits manually (try: git commit --amend -S)".to_string()
3. You can sign commits manually (try: git commit --amend -S)"
.to_string()
} else {
stdout.to_string()
}
} else {
stderr.to_string()
};
bail!("Failed to amend commit: {}", error_msg);
}
@@ -378,26 +405,26 @@ impl CommitCommand {
}
}
// Helper trait for optional builder methods
trait CommitBuilderExt {
fn scope_opt(self, scope: Option<String>) -> Self;
fn body_opt(self, body: Option<String>) -> Self;
}
// // Helper trait for optional builder methods
// trait CommitBuilderExt {
// fn scope_opt(self, scope: Option<String>) -> Self;
// fn body_opt(self, body: Option<String>) -> Self;
// }
impl CommitBuilderExt for CommitBuilder {
fn scope_opt(self, scope: Option<String>) -> Self {
if let Some(s) = scope {
self.scope(s)
} else {
self
}
}
// impl CommitBuilderExt for CommitBuilder {
// fn scope_opt(self, scope: Option<String>) -> Self {
// if let Some(s) = scope {
// self.scope(s)
// } else {
// self
// }
// }
fn body_opt(self, body: Option<String>) -> Self {
if let Some(b) = body {
self.body(b)
} else {
self
}
}
}
// fn body_opt(self, body: Option<String>) -> Self {
// if let Some(b) = body {
// self.body(b)
// } else {
// self
// }
// }
// }

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,11 @@ use colored::Colorize;
use dialoguer::{Confirm, Input, Select};
use std::path::PathBuf;
use crate::config::{GitProfile, Language};
use crate::config::manager::ConfigManager;
use crate::config::profile::{GpgConfig, SshConfig};
use crate::config::{GitProfile, Language};
use crate::i18n::Messages;
use crate::utils::keyring::{get_default_model, get_supported_providers, provider_needs_api_key};
use crate::utils::validators::validate_email;
/// Initialize quicommit configuration
@@ -27,35 +28,34 @@ impl InitCommand {
let messages = Messages::new(Language::English);
println!("{}", messages.initializing().bold().cyan());
let config_path = config_path.unwrap_or_else(|| {
crate::config::AppConfig::default_path().unwrap()
});
// Check if config already exists
let config_path =
config_path.unwrap_or_else(|| crate::config::AppConfig::default_path().unwrap());
if config_path.exists() && !self.reset {
if !self.yes {
let overwrite = Confirm::new()
.with_prompt("Configuration already exists. Overwrite?")
.default(false)
.interact()?;
if !overwrite {
println!("{}", "Initialization cancelled.".yellow());
return Ok(());
}
} else {
println!("{}", "Configuration already exists. Use --reset to overwrite.".yellow());
println!(
"{}",
"Configuration already exists. Use --reset to overwrite.".yellow()
);
return Ok(());
}
}
// Create parent directory if needed
if let Some(parent) = config_path.parent() {
std::fs::create_dir_all(parent)
.map_err(|e| anyhow::anyhow!("Failed to create config directory: {}", e))?;
}
// Create new config manager with fresh config
let mut manager = ConfigManager::with_path_fresh(&config_path)?;
if self.yes {
@@ -65,11 +65,10 @@ impl InitCommand {
}
manager.save()?;
// Get configured language for final messages
let language = manager.get_language().unwrap_or(Language::English);
let messages = Messages::new(language);
println!("{}", messages.init_success().bold().green());
println!("\n{}: {}", messages.config_file(), config_path.display());
println!("\n{}:", messages.next_steps());
@@ -81,22 +80,20 @@ impl InitCommand {
}
async fn quick_setup(&self, manager: &mut ConfigManager) -> Result<()> {
// Try to get git user info
let git_config = git2::Config::open_default()?;
let user_name = git_config.get_string("user.name").unwrap_or_else(|_| "User".to_string());
let user_email = git_config.get_string("user.email").unwrap_or_else(|_| "user@example.com".to_string());
let profile = GitProfile::new(
"default".to_string(),
user_name,
user_email,
);
let user_name = git_config
.get_string("user.name")
.unwrap_or_else(|_| "User".to_string());
let user_email = git_config
.get_string("user.email")
.unwrap_or_else(|_| "user@example.com".to_string());
let profile = GitProfile::new("default".to_string(), user_name, user_email);
manager.add_profile("default".to_string(), profile)?;
manager.set_default_profile(Some("default".to_string()))?;
// Set default LLM to Ollama
manager.set_llm_provider("ollama".to_string());
Ok(())
@@ -106,9 +103,8 @@ impl InitCommand {
let messages = Messages::new(Language::English);
println!("\n{}", messages.setup_profile().bold());
// Language selection
println!("\n{}", messages.select_output_language().bold());
let languages = vec![
let languages = [
Language::English,
Language::Chinese,
Language::Japanese,
@@ -117,32 +113,31 @@ impl InitCommand {
Language::French,
Language::German,
];
let language_names: Vec<String> = languages.iter().map(|l| l.display_name().to_string()).collect();
let language_idx = Select::new()
.items(&language_names)
.default(0)
.interact()?;
let language_names: Vec<String> = languages
.iter()
.map(|l| l.display_name().to_string())
.collect();
let language_idx = Select::new().items(&language_names).default(0).interact()?;
let selected_language = languages[language_idx];
manager.set_output_language(selected_language.to_code().to_string());
// Update messages to selected language
let messages = Messages::new(selected_language);
// Profile name
let profile_name: String = Input::new()
.with_prompt(messages.profile_name())
.default("personal".to_string())
.interact_text()?;
// User info
let git_config = git2::Config::open_default().ok();
let default_name = git_config.as_ref()
let default_name = git_config
.as_ref()
.and_then(|c| c.get_string("user.name").ok())
.unwrap_or_default();
let default_email = git_config.as_ref()
let default_email = git_config
.as_ref()
.and_then(|c| c.get_string("user.email").ok())
.unwrap_or_default();
@@ -154,9 +149,7 @@ impl InitCommand {
let user_email: String = Input::new()
.with_prompt(messages.git_user_email())
.default(default_email)
.validate_with(|input: &String| {
validate_email(input).map_err(|e| e.to_string())
})
.validate_with(|input: &String| validate_email(input).map_err(|e| e.to_string()))
.interact_text()?;
let description: String = Input::new()
@@ -170,14 +163,15 @@ impl InitCommand {
.interact()?;
let organization = if is_work {
Some(Input::new()
.with_prompt(messages.organization_name())
.interact_text()?)
Some(
Input::new()
.with_prompt(messages.organization_name())
.interact_text()?,
)
} else {
None
};
// SSH configuration
let setup_ssh = Confirm::new()
.with_prompt(messages.configure_ssh())
.default(false)
@@ -189,7 +183,6 @@ impl InitCommand {
None
};
// GPG configuration
let setup_gpg = Confirm::new()
.with_prompt(messages.configure_gpg())
.default(false)
@@ -201,12 +194,7 @@ impl InitCommand {
None
};
// Create profile
let mut profile = GitProfile::new(
profile_name.clone(),
user_name,
user_email,
);
let mut profile = GitProfile::new(profile_name.clone(), user_name, user_email);
if !description.is_empty() {
profile.description = Some(description);
@@ -220,59 +208,112 @@ impl InitCommand {
manager.add_profile(profile_name.clone(), profile)?;
manager.set_default_profile(Some(profile_name))?;
// LLM provider selection
println!("\n{}", messages.select_llm_provider().bold());
let providers = vec![
let provider_display_names = vec![
"Ollama (local)",
"OpenAI",
"Anthropic Claude",
"Kimi (Moonshot AI)",
"DeepSeek",
"OpenRouter"
"OpenRouter",
];
let provider_idx = Select::new()
.items(&providers)
.items(&provider_display_names)
.default(0)
.interact()?;
let provider = match provider_idx {
0 => "ollama",
1 => "openai",
2 => "anthropic",
3 => "kimi",
4 => "deepseek",
5 => "openrouter",
_ => "ollama",
let providers = get_supported_providers();
let provider = providers[provider_idx].to_string();
let keyring = manager.keyring();
let keyring_available = keyring.is_available();
if !keyring_available {
println!(
"\n{}",
"⚠ Keyring is not available on this system.".yellow()
);
println!("{}", keyring.get_status_message().yellow());
}
let api_key = if provider_needs_api_key(&provider) {
let env_key = std::env::var("QUICOMMIT_API_KEY")
.or_else(|_| {
std::env::var(format!("QUICOMMIT_{}_API_KEY", provider.to_uppercase()))
})
.ok();
if let Some(_key) = env_key {
println!(
"\n{} {}",
"".green(),
"Found API key in environment variable.".green()
);
None
} else if keyring_available {
let prompt = match provider.as_str() {
"openai" => messages.openai_api_key(),
"anthropic" => messages.anthropic_api_key(),
"kimi" => messages.kimi_api_key(),
"deepseek" => messages.deepseek_api_key(),
"openrouter" => messages.openrouter_api_key(),
_ => "API Key",
};
let key: String = Input::new().with_prompt(prompt).interact_text()?;
Some(key)
} else {
println!(
"\n{}",
"Please set the QUICOMMIT_API_KEY environment variable.".yellow()
);
None
}
} else {
None
};
manager.set_llm_provider(provider.to_string());
let default_model = get_default_model(&provider);
let model: String = Input::new()
.with_prompt("Model name")
.default(default_model.to_string())
.interact_text()?;
// Configure API key if needed
if provider == "openai" {
let api_key: String = Input::new()
.with_prompt(messages.openai_api_key())
let base_url: Option<String> = if provider == "ollama" {
let url: String = Input::new()
.with_prompt("Ollama server URL")
.default("http://localhost:11434".to_string())
.interact_text()?;
manager.set_openai_api_key(api_key);
} else if provider == "anthropic" {
let api_key: String = Input::new()
.with_prompt(messages.anthropic_api_key())
.interact_text()?;
manager.set_anthropic_api_key(api_key);
} else if provider == "kimi" {
let api_key: String = Input::new()
.with_prompt(messages.kimi_api_key())
.interact_text()?;
manager.set_kimi_api_key(api_key);
} else if provider == "deepseek" {
let api_key: String = Input::new()
.with_prompt(messages.deepseek_api_key())
.interact_text()?;
manager.set_deepseek_api_key(api_key);
} else if provider == "openrouter" {
let api_key: String = Input::new()
.with_prompt(messages.openrouter_api_key())
.interact_text()?;
manager.set_openrouter_api_key(api_key);
Some(url)
} else {
let use_custom_url = Confirm::new()
.with_prompt("Use custom API base URL?")
.default(false)
.interact()?;
if use_custom_url {
let url: String = Input::new().with_prompt("Base URL").interact_text()?;
Some(url)
} else {
None
}
};
manager.set_llm_provider(provider.clone());
manager.set_llm_model(model);
manager.set_llm_base_url(base_url);
if let Some(key) = api_key
&& provider_needs_api_key(&provider)
{
manager.set_api_key(&key)?;
println!(
"\n{} {}",
"".green(),
"API key stored securely in system keyring.".green()
);
}
Ok(())

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use clap::Parser;
use colored::Colorize;
use dialoguer::{Confirm, Input, Select};
@@ -6,11 +6,11 @@ use semver::Version;
use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager};
use crate::git::{find_repo, GitRepo};
use crate::generator::ContentGenerator;
use crate::git::tag::{
bump_version, get_latest_version, suggest_version_bump, TagBuilder, VersionBump,
TagBuilder, VersionBump, bump_version, get_latest_version, suggest_version_bump,
};
use crate::git::{GitRepo, find_repo};
use crate::i18n::Messages;
/// Generate and create Git tags
@@ -56,6 +56,10 @@ pub struct TagCommand {
#[arg(long)]
dry_run: bool,
/// Enable thinking mode for this tag (override config)
#[arg(short = 't', long)]
think: bool,
/// Skip interactive prompts
#[arg(short = 'y', long)]
yes: bool,
@@ -79,30 +83,37 @@ impl TagCommand {
} else if let Some(bump_str) = &self.bump {
// Calculate bumped version
let prefix = &config.tag.version_prefix;
let latest = get_latest_version(&repo, prefix)?
.unwrap_or_else(|| Version::new(0, 0, 0));
let latest =
get_latest_version(&repo, prefix)?.unwrap_or_else(|| Version::new(0, 0, 0));
let bump = VersionBump::from_str(bump_str)?;
let new_version = bump_version(&latest, bump, None);
format!("{}{}", prefix, new_version)
} else {
// Interactive mode
self.select_version_interactive(&repo, &config.tag.version_prefix, &messages).await?
self.select_version_interactive(&repo, &config.tag.version_prefix, &messages)
.await?
};
// Validate tag name (if it looks like a version)
if tag_name.starts_with('v') || tag_name.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) {
if tag_name.starts_with('v')
|| tag_name
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
let version_str = tag_name.trim_start_matches('v');
if let Err(e) = crate::utils::validators::validate_semver(version_str) {
println!("{}: {}", "Warning".yellow(), e);
if !self.yes {
let proceed = Confirm::new()
.with_prompt("Proceed with this tag name anyway?")
.default(true)
.interact()?;
if !proceed {
bail!("{}", messages.tag_cancelled());
}
@@ -116,7 +127,10 @@ impl TagCommand {
} else if let Some(msg) = &self.message {
Some(msg.clone())
} else if self.generate || (config.tag.auto_generate && !self.yes) {
Some(self.generate_tag_message(&repo, &tag_name, &messages).await?)
Some(
self.generate_tag_message(&repo, &tag_name, &messages)
.await?,
)
} else if !self.yes {
Some(self.input_message_interactive(&tag_name, &messages)?)
} else {
@@ -184,12 +198,17 @@ impl TagCommand {
Ok(())
}
async fn select_version_interactive(&self, repo: &GitRepo, prefix: &str, messages: &Messages) -> Result<String> {
async fn select_version_interactive(
&self,
repo: &GitRepo,
prefix: &str,
messages: &Messages,
) -> Result<String> {
loop {
let latest = get_latest_version(repo, prefix)?;
println!("\n{}", messages.version_selection().bold());
if let Some(ref version) = latest {
println!("{} {}{}", messages.latest_version(), prefix, version);
} else {
@@ -216,36 +235,46 @@ impl TagCommand {
// Auto-detect
let commits = repo.get_commits(50)?;
let bump = suggest_version_bump(&commits);
let version = latest.as_ref()
let version = latest
.as_ref()
.map(|v| bump_version(v, bump, None))
.unwrap_or_else(|| Version::new(0, 1, 0));
println!("{} {:?}{}{}", messages.suggested_bump(), bump, prefix, version);
println!(
"{} {:?}{}{}",
messages.suggested_bump(),
bump,
prefix,
version
);
let confirm = Confirm::new()
.with_prompt(messages.use_this_version())
.default(true)
.interact()?;
if confirm {
return Ok(format!("{}{}", prefix, version));
}
// User rejected, continue the loop
}
1 => {
let version = latest.as_ref()
let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Major, None))
.unwrap_or_else(|| Version::new(1, 0, 0));
return Ok(format!("{}{}", prefix, version));
}
2 => {
let version = latest.as_ref()
let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Minor, None))
.unwrap_or_else(|| Version::new(0, 1, 0));
return Ok(format!("{}{}", prefix, version));
}
3 => {
let version = latest.as_ref()
let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Patch, None))
.unwrap_or_else(|| Version::new(0, 0, 1));
return Ok(format!("{}{}", prefix, version));
@@ -268,12 +297,15 @@ impl TagCommand {
}
}
async fn generate_tag_message(&self, repo: &GitRepo, version: &str, messages: &Messages) -> Result<String> {
async fn generate_tag_message(
&self,
repo: &GitRepo,
version: &str,
messages: &Messages,
) -> Result<String> {
let manager = ConfigManager::new()?;
let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English);
// Get commits since last tag
let tags = repo.get_tags()?;
let commits = if let Some(latest_tag) = tags.first() {
repo.get_commits_between(&latest_tag.name, "HEAD")?
@@ -287,18 +319,20 @@ impl TagCommand {
println!("{}", messages.ai_generating_tag(commits.len()));
let generator = ContentGenerator::new(&config.llm).await?;
generator.generate_tag_message(version, &commits, language).await
let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
generator
.generate_tag_message(version, &commits, language)
.await
}
fn input_message_interactive(&self, version: &str, messages: &Messages) -> Result<String> {
let default_msg = format!("Release {}", version);
let use_editor = Confirm::new()
.with_prompt(messages.open_editor())
.default(false)
.interact()?;
if use_editor {
crate::utils::editor::edit_content(&default_msg)
} else {

View File

@@ -1,6 +1,9 @@
use super::{AppConfig, GitProfile, TokenConfig};
use anyhow::{bail, Context, Result};
use std::collections::HashMap;
use crate::utils::keyring::{
KeyringManager, get_default_base_url, get_default_model, provider_needs_api_key,
};
use anyhow::{Context, Result, bail};
// use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Configuration manager
@@ -8,6 +11,7 @@ pub struct ConfigManager {
config: AppConfig,
config_path: PathBuf,
modified: bool,
keyring: KeyringManager,
}
impl ConfigManager {
@@ -28,6 +32,7 @@ impl ConfigManager {
config,
config_path: path.to_path_buf(),
modified: false,
keyring: KeyringManager::new(),
})
}
@@ -37,6 +42,7 @@ impl ConfigManager {
config: AppConfig::default(),
config_path: path.to_path_buf(),
modified: true,
keyring: KeyringManager::new(),
})
}
@@ -60,10 +66,10 @@ impl ConfigManager {
Ok(())
}
/// Force save configuration
pub fn force_save(&self) -> Result<()> {
self.config.save(&self.config_path)
}
// /// Force save configuration
// pub fn force_save(&self) -> Result<()> {
// self.config.save(&self.config_path)
// }
/// Get configuration file path
pub fn path(&self) -> &Path {
@@ -87,13 +93,13 @@ impl ConfigManager {
if !self.config.profiles.contains_key(name) {
bail!("Profile '{}' does not exist", name);
}
if self.config.default_profile.as_ref() == Some(&name.to_string()) {
self.config.default_profile = None;
}
self.config.repo_profiles.retain(|_, v| v != name);
self.config.profiles.remove(name);
self.modified = true;
Ok(())
@@ -114,11 +120,11 @@ impl ConfigManager {
self.config.profiles.get(name)
}
/// Get mutable profile
pub fn get_profile_mut(&mut self, name: &str) -> Option<&mut GitProfile> {
self.modified = true;
self.config.profiles.get_mut(name)
}
// /// Get mutable profile
// pub fn get_profile_mut(&mut self, name: &str) -> Option<&mut GitProfile> {
// self.modified = true;
// self.config.profiles.get_mut(name)
// }
/// List all profile names
pub fn list_profiles(&self) -> Vec<&String> {
@@ -132,10 +138,10 @@ impl ConfigManager {
/// Set default profile
pub fn set_default_profile(&mut self, name: Option<String>) -> Result<()> {
if let Some(ref n) = name {
if !self.config.profiles.contains_key(n) {
bail!("Profile '{}' does not exist", n);
}
if let Some(ref n) = name
&& !self.config.profiles.contains_key(n)
{
bail!("Profile '{}' does not exist", n);
}
self.config.default_profile = name;
self.modified = true;
@@ -166,54 +172,138 @@ impl ConfigManager {
}
}
/// Get profile usage statistics
pub fn get_profile_usage(&self, name: &str) -> Option<&super::UsageStats> {
self.config.profiles.get(name).map(|p| &p.usage)
}
// /// Get profile usage statistics
// pub fn get_profile_usage(&self, name: &str) -> Option<&super::UsageStats> {
// self.config.profiles.get(name).map(|p| &p.usage)
// }
// Token management
/// Add a token to a profile
pub fn add_token_to_profile(&mut self, profile_name: &str, service: String, token: TokenConfig) -> Result<()> {
/// Add a token to a profile (stores token in keyring)
pub fn add_token_to_profile(
&mut self,
profile_name: &str,
service: String,
token: TokenConfig,
) -> Result<()> {
if !self.config.profiles.contains_key(profile_name) {
bail!("Profile '{}' does not exist", profile_name);
}
if let Some(profile) = self.config.profiles.get_mut(profile_name) {
profile.add_token(service, token);
self.modified = true;
Ok(())
}
Ok(())
}
/// Store a PAT token in keyring for a profile
pub fn store_pat_for_profile(
&self,
profile_name: &str,
service: &str,
token_value: &str,
) -> Result<()> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
let user_email = &profile.user_email;
self.keyring
.store_pat(profile_name, user_email, service, token_value)
}
/// Get a PAT token from keyring for a profile
pub fn get_pat_for_profile(&self, profile_name: &str, service: &str) -> Result<Option<String>> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
let user_email = &profile.user_email;
self.keyring.get_pat(profile_name, user_email, service)
}
/// Check if a PAT token exists for a profile
pub fn has_pat_for_profile(&self, profile_name: &str, service: &str) -> bool {
if let Some(profile) = self.get_profile(profile_name) {
let user_email = &profile.user_email;
self.keyring.has_pat(profile_name, user_email, service)
} else {
bail!("Profile '{}' does not exist", profile_name);
false
}
}
/// Get a token from a profile
pub fn get_token_from_profile(&self, profile_name: &str, service: &str) -> Option<&TokenConfig> {
self.config.profiles.get(profile_name)?.get_token(service)
}
/// Remove a token from a profile
/// Remove a token from a profile (deletes from keyring)
pub fn remove_token_from_profile(&mut self, profile_name: &str, service: &str) -> Result<()> {
if !self.config.profiles.contains_key(profile_name) {
bail!("Profile '{}' does not exist", profile_name);
}
let user_email = self
.config
.profiles
.get(profile_name)
.unwrap()
.user_email
.clone();
let services: Vec<String> = self
.config
.profiles
.get(profile_name)
.unwrap()
.tokens
.keys()
.cloned()
.collect();
if !services.contains(&service.to_string()) {
bail!(
"Token for service '{}' not found in profile '{}'",
service,
profile_name
);
}
self.keyring
.delete_pat(profile_name, &user_email, service)?;
if let Some(profile) = self.config.profiles.get_mut(profile_name) {
profile.remove_token(service);
self.modified = true;
Ok(())
} else {
bail!("Profile '{}' does not exist", profile_name);
}
Ok(())
}
/// List all tokens in a profile
pub fn list_profile_tokens(&self, profile_name: &str) -> Option<Vec<&String>> {
self.config.profiles.get(profile_name).map(|p| p.tokens.keys().collect())
/// Delete all PAT tokens for a profile (used when removing a profile)
pub fn delete_all_pats_for_profile(&self, profile_name: &str) -> Result<()> {
if let Some(profile) = self.get_profile(profile_name) {
let user_email = &profile.user_email;
let services: Vec<String> = profile.tokens.keys().cloned().collect();
self.keyring
.delete_all_pats_for_profile(profile_name, user_email, &services)?;
}
Ok(())
}
// /// List all tokens in a profile
// pub fn list_profile_tokens(&self, profile_name: &str) -> Option<Vec<&String>> {
// self.config.profiles.get(profile_name).map(|p| p.tokens.keys().collect())
// }
// Repository profile management
/// Get profile for repository
pub fn get_repo_profile(&self, repo_path: &str) -> Option<&GitProfile> {
self.config
.repo_profiles
.get(repo_path)
.and_then(|name| self.config.profiles.get(name))
}
// /// Get profile for repository
// pub fn get_repo_profile(&self, repo_path: &str) -> Option<&GitProfile> {
// self.config
// .repo_profiles
// .get(repo_path)
// .and_then(|name| self.config.profiles.get(name))
// }
/// Set profile for repository
pub fn set_repo_profile(&mut self, repo_path: String, profile_name: String) -> Result<()> {
@@ -225,32 +315,75 @@ impl ConfigManager {
Ok(())
}
/// Remove repository profile mapping
pub fn remove_repo_profile(&mut self, repo_path: &str) {
self.config.repo_profiles.remove(repo_path);
self.modified = true;
// /// Remove repository profile mapping
// pub fn remove_repo_profile(&mut self, repo_path: &str) {
// self.config.repo_profiles.remove(repo_path);
// self.modified = true;
// }
// /// List repository profile mappings
// pub fn list_repo_profiles(&self) -> &HashMap<String, String> {
// &self.config.repo_profiles
// }
// /// Get effective profile for a repository (repo-specific -> default)
// pub fn get_effective_profile(&self, repo_path: Option<&str>) -> Option<&GitProfile> {
// if let Some(path) = repo_path {
// if let Some(profile) = self.get_repo_profile(path) {
// return Some(profile);
// }
// }
// self.default_profile()
// }
/// Check and compare profile with git configuration
pub fn check_profile_config(
&self,
profile_name: &str,
repo: &git2::Repository,
) -> Result<super::ProfileComparison> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
profile.compare_with_git_config(repo)
}
/// List repository profile mappings
pub fn list_repo_profiles(&self) -> &HashMap<String, String> {
&self.config.repo_profiles
}
/// Find a profile that matches the given user config (name, email, signing_key)
pub fn find_matching_profile(
&self,
user_name: &str,
user_email: &str,
signing_key: Option<&str>,
) -> Option<&GitProfile> {
for profile in self.config.profiles.values() {
let name_match = profile.user_name == user_name;
let email_match = profile.user_email == user_email;
let key_match = match (signing_key, profile.signing_key()) {
(Some(git_key), Some(profile_key)) => git_key == profile_key,
(None, None) => true,
(Some(_), None) => false,
(None, Some(_)) => false,
};
/// Get effective profile for a repository (repo-specific -> default)
pub fn get_effective_profile(&self, repo_path: Option<&str>) -> Option<&GitProfile> {
if let Some(path) = repo_path {
if let Some(profile) = self.get_repo_profile(path) {
if name_match && email_match && key_match {
return Some(profile);
}
}
self.default_profile()
None
}
/// Check and compare profile with git configuration
pub fn check_profile_config(&self, profile_name: &str, repo: &git2::Repository) -> Result<super::ProfileComparison> {
let profile = self.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
profile.compare_with_git_config(repo)
/// Find profiles that partially match (same name or same email)
pub fn find_partial_matches(&self, user_name: &str, user_email: &str) -> Vec<&GitProfile> {
self.config
.profiles
.values()
.filter(|p| p.user_name == user_name || p.user_email == user_email)
.collect()
}
/// Get repo profile mapping
pub fn get_repo_profile_name(&self, repo_path: &str) -> Option<&String> {
self.config.repo_profiles.get(repo_path)
}
// LLM configuration
@@ -262,104 +395,169 @@ impl ConfigManager {
/// Set LLM provider
pub fn set_llm_provider(&mut self, provider: String) {
self.config.llm.provider = provider;
let default_model = get_default_model(&provider);
self.config.llm.provider = provider.clone();
if self.config.llm.model.is_empty() || self.config.llm.model == "llama2" {
self.config.llm.model = default_model.to_string();
}
self.modified = true;
}
/// Get OpenAI API key
pub fn openai_api_key(&self) -> Option<&String> {
self.config.llm.openai.api_key.as_ref()
/// Get model
pub fn llm_model(&self) -> &str {
&self.config.llm.model
}
/// Set OpenAI API key
pub fn set_openai_api_key(&mut self, key: String) {
self.config.llm.openai.api_key = Some(key);
/// Set model
pub fn set_llm_model(&mut self, model: String) {
self.config.llm.model = model;
self.modified = true;
}
/// Get Anthropic API key
pub fn anthropic_api_key(&self) -> Option<&String> {
self.config.llm.anthropic.api_key.as_ref()
/// Get base URL (returns provider default if not set)
pub fn llm_base_url(&self) -> String {
match &self.config.llm.base_url {
Some(url) => url.clone(),
None => get_default_base_url(&self.config.llm.provider).to_string(),
}
}
/// Set Anthropic API key
pub fn set_anthropic_api_key(&mut self, key: String) {
self.config.llm.anthropic.api_key = Some(key);
/// Set base URL
pub fn set_llm_base_url(&mut self, url: Option<String>) {
self.config.llm.base_url = url;
self.modified = true;
}
/// Get Kimi API key
pub fn kimi_api_key(&self) -> Option<&String> {
self.config.llm.kimi.api_key.as_ref()
/// Get API key from configured storage method
pub fn get_api_key(&self) -> Option<String> {
// First try environment variables (always checked)
if let Some(key) = self
.keyring
.get_api_key(&self.config.llm.provider)
.unwrap_or(None)
{
return Some(key);
}
// Then try config file if configured
if self.config.llm.api_key_storage == "config" {
return self.config.llm.api_key.clone();
}
None
}
/// Set Kimi API key
pub fn set_kimi_api_key(&mut self, key: String) {
self.config.llm.kimi.api_key = Some(key);
self.modified = true;
/// Store API key in configured storage method
pub fn set_api_key(&self, api_key: &str) -> Result<()> {
match self.config.llm.api_key_storage.as_str() {
"keyring" => {
if !self.keyring.is_available() {
bail!(
"Keyring is not available. Set QUICOMMIT_API_KEY environment variable instead or change api_key_storage to 'config'."
);
}
self.keyring
.store_api_key(&self.config.llm.provider, api_key)
}
"config" => {
// We can't modify self.config here since self is immutable
// This will be handled by the caller updating the config
Ok(())
}
"environment" => {
bail!(
"API key storage set to 'environment'. Please set QUICOMMIT_{}_API_KEY environment variable.",
self.config.llm.provider.to_uppercase()
);
}
_ => {
bail!(
"Invalid API key storage method: {}",
self.config.llm.api_key_storage
);
}
}
}
/// Get Kimi base URL
pub fn kimi_base_url(&self) -> &str {
&self.config.llm.kimi.base_url
/// Delete API key from configured storage method
pub fn delete_api_key(&self) -> Result<()> {
match self.config.llm.api_key_storage.as_str() {
"keyring" => {
if self.keyring.is_available() {
self.keyring.delete_api_key(&self.config.llm.provider)?;
}
}
"config" => {
// We can't modify self.config here since self is immutable
// This will be handled by the caller updating the config
}
"environment" => {
// Environment variables are not managed by the app
}
_ => {
bail!(
"Invalid API key storage method: {}",
self.config.llm.api_key_storage
);
}
}
Ok(())
}
/// Set Kimi base URL
pub fn set_kimi_base_url(&mut self, url: String) {
self.config.llm.kimi.base_url = url;
self.modified = true;
/// Check if API key is configured
pub fn has_api_key(&self) -> bool {
if !provider_needs_api_key(&self.config.llm.provider) {
return true;
}
// Check environment variables
if self
.keyring
.get_api_key(&self.config.llm.provider)
.unwrap_or(None)
.is_some()
{
return true;
}
// Check config file if configured
if self.config.llm.api_key_storage == "config" {
return self.config.llm.api_key.is_some();
}
false
}
/// Get DeepSeek API key
pub fn deepseek_api_key(&self) -> Option<&String> {
self.config.llm.deepseek.api_key.as_ref()
/// Get keyring manager reference
pub fn keyring(&self) -> &KeyringManager {
&self.keyring
}
/// Set DeepSeek API key
pub fn set_deepseek_api_key(&mut self, key: String) {
self.config.llm.deepseek.api_key = Some(key);
self.modified = true;
}
// /// Configure LLM provider with all settings
// pub fn configure_llm(&mut self, provider: String, model: Option<String>, base_url: Option<String>, api_key: Option<&str>) -> Result<()> {
// self.set_llm_provider(provider.clone());
/// Get DeepSeek base URL
pub fn deepseek_base_url(&self) -> &str {
&self.config.llm.deepseek.base_url
}
// if let Some(m) = model {
// self.set_llm_model(m);
// }
/// Set DeepSeek base URL
pub fn set_deepseek_base_url(&mut self, url: String) {
self.config.llm.deepseek.base_url = url;
self.modified = true;
}
// self.set_llm_base_url(base_url);
/// Get OpenRouter API key
pub fn openrouter_api_key(&self) -> Option<&String> {
self.config.llm.openrouter.api_key.as_ref()
}
// if let Some(key) = api_key {
// if provider_needs_api_key(&provider) {
// self.set_api_key(key)?;
// }
// }
/// Set OpenRouter API key
pub fn set_openrouter_api_key(&mut self, key: String) {
self.config.llm.openrouter.api_key = Some(key);
self.modified = true;
}
/// Get OpenRouter base URL
pub fn openrouter_base_url(&self) -> &str {
&self.config.llm.openrouter.base_url
}
/// Set OpenRouter base URL
pub fn set_openrouter_base_url(&mut self, url: String) {
self.config.llm.openrouter.base_url = url;
self.modified = true;
}
// Ok(())
// }
// Commit configuration
/// Get commit format
pub fn commit_format(&self) -> super::CommitFormat {
self.config.commit.format
}
// /// Get commit format
// pub fn commit_format(&self) -> super::CommitFormat {
// self.config.commit.format
// }
/// Set commit format
pub fn set_commit_format(&mut self, format: super::CommitFormat) {
@@ -367,10 +565,10 @@ impl ConfigManager {
self.modified = true;
}
/// Check if auto-generate is enabled
pub fn auto_generate_commits(&self) -> bool {
self.config.commit.auto_generate
}
// /// Check if auto-generate is enabled
// pub fn auto_generate_commits(&self) -> bool {
// self.config.commit.auto_generate
// }
/// Set auto-generate commits
pub fn set_auto_generate_commits(&mut self, enabled: bool) {
@@ -380,10 +578,10 @@ impl ConfigManager {
// Tag configuration
/// Get version prefix
pub fn version_prefix(&self) -> &str {
&self.config.tag.version_prefix
}
// /// Get version prefix
// pub fn version_prefix(&self) -> &str {
// &self.config.tag.version_prefix
// }
/// Set version prefix
pub fn set_version_prefix(&mut self, prefix: String) {
@@ -393,10 +591,10 @@ impl ConfigManager {
// Changelog configuration
/// Get changelog path
pub fn changelog_path(&self) -> &str {
&self.config.changelog.path
}
// /// Get changelog path
// pub fn changelog_path(&self) -> &str {
// &self.config.changelog.path
// }
/// Set changelog path
pub fn set_changelog_path(&mut self, path: String) {
@@ -406,10 +604,10 @@ impl ConfigManager {
// Language configuration
/// Get output language
pub fn output_language(&self) -> &str {
&self.config.language.output_language
}
// /// Get output language
// pub fn output_language(&self) -> &str {
// &self.config.language.output_language
// }
/// Set output language
pub fn set_output_language(&mut self, language: String) {
@@ -446,14 +644,12 @@ impl ConfigManager {
/// Export configuration to TOML string
pub fn export(&self) -> Result<String> {
toml::to_string_pretty(&self.config)
.context("Failed to serialize config")
toml::to_string_pretty(&self.config).context("Failed to serialize config")
}
/// Import configuration from TOML string
pub fn import(&mut self, toml_str: &str) -> Result<()> {
self.config = toml::from_str(toml_str)
.context("Failed to parse config")?;
self.config = toml::from_str(toml_str).context("Failed to parse config")?;
self.modified = true;
Ok(())
}
@@ -471,6 +667,7 @@ impl Default for ConfigManager {
config: AppConfig::default(),
config_path: PathBuf::new(),
modified: false,
keyring: KeyringManager::new(),
}
}
}

View File

@@ -7,10 +7,7 @@ use std::path::{Path, PathBuf};
pub mod manager;
pub mod profile;
pub use profile::{
GitProfile, TokenConfig, TokenType,
UsageStats, ProfileComparison
};
pub use profile::{GitProfile, ProfileComparison, TokenConfig, TokenType};
/// Application configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -80,37 +77,16 @@ impl Default for AppConfig {
/// LLM configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// Default LLM provider
/// Current LLM provider (ollama, openai, anthropic, kimi, deepseek, openrouter)
#[serde(default = "default_llm_provider")]
pub provider: String,
/// OpenAI configuration
#[serde(default)]
pub openai: OpenAiConfig,
/// Model to use (stored in config, not in keyring)
#[serde(default = "default_model")]
pub model: String,
/// Ollama configuration
#[serde(default)]
pub ollama: OllamaConfig,
/// Anthropic Claude configuration
#[serde(default)]
pub anthropic: AnthropicConfig,
/// Kimi (Moonshot AI) configuration
#[serde(default)]
pub kimi: KimiConfig,
/// DeepSeek configuration
#[serde(default)]
pub deepseek: DeepSeekConfig,
/// OpenRouter configuration
#[serde(default)]
pub openrouter: OpenRouterConfig,
/// Custom API configuration
#[serde(default)]
pub custom: Option<CustomLlmConfig>,
/// API base URL (optional, will use provider default if not set)
pub base_url: Option<String>,
/// Maximum tokens for generation
#[serde(default = "default_max_tokens")]
@@ -123,186 +99,45 @@ pub struct LlmConfig {
/// Timeout in seconds
#[serde(default = "default_timeout")]
pub timeout: u64,
/// API key storage method (keyring, config, environment)
#[serde(default = "default_api_key_storage")]
pub api_key_storage: String,
/// API key (stored in config for fallback, encrypted if encrypt_sensitive is true)
#[serde(default)]
pub api_key: Option<String>,
/// Enable thinking/reasoning mode (deepseek, kimi, anthropic)
#[serde(default)]
pub thinking_enabled: bool,
/// Budget tokens for thinking mode (Anthropic Claude 4)
#[serde(default)]
pub thinking_budget_tokens: Option<u32>,
}
fn default_api_key_storage() -> String {
"keyring".to_string()
}
impl Default for LlmConfig {
fn default() -> Self {
Self {
provider: default_llm_provider(),
openai: OpenAiConfig::default(),
ollama: OllamaConfig::default(),
anthropic: AnthropicConfig::default(),
kimi: KimiConfig::default(),
deepseek: DeepSeekConfig::default(),
openrouter: OpenRouterConfig::default(),
custom: None,
model: default_model(),
base_url: None,
max_tokens: default_max_tokens(),
temperature: default_temperature(),
timeout: default_timeout(),
}
}
}
/// OpenAI API configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_openai_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_openai_base_url")]
pub base_url: String,
}
impl Default for OpenAiConfig {
fn default() -> Self {
Self {
api_key_storage: default_api_key_storage(),
api_key: None,
model: default_openai_model(),
base_url: default_openai_base_url(),
thinking_enabled: false,
thinking_budget_tokens: None,
}
}
}
/// Ollama configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
/// Ollama server URL
#[serde(default = "default_ollama_url")]
pub url: String,
/// Model to use
#[serde(default = "default_ollama_model")]
pub model: String,
}
impl Default for OllamaConfig {
fn default() -> Self {
Self {
url: default_ollama_url(),
model: default_ollama_model(),
}
}
}
/// Anthropic Claude configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_anthropic_model")]
pub model: String,
}
impl Default for AnthropicConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_anthropic_model(),
}
}
}
/// Kimi (Moonshot AI) configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KimiConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_kimi_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_kimi_base_url")]
pub base_url: String,
}
impl Default for KimiConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_kimi_model(),
base_url: default_kimi_base_url(),
}
}
}
/// DeepSeek configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_deepseek_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_deepseek_base_url")]
pub base_url: String,
}
impl Default for DeepSeekConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_deepseek_model(),
base_url: default_deepseek_base_url(),
}
}
}
/// OpenRouter configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_openrouter_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_openrouter_base_url")]
pub base_url: String,
}
impl Default for OpenRouterConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_openrouter_model(),
base_url: default_openrouter_base_url(),
}
}
}
/// Custom LLM API configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomLlmConfig {
/// API endpoint URL
pub url: String,
/// API key (optional)
pub api_key: Option<String>,
/// Model name
pub model: String,
/// Request format template (JSON)
pub request_template: String,
/// Response path to extract content (e.g., "choices.0.message.content")
pub response_path: String,
}
/// Commit configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitConfig {
@@ -592,6 +427,10 @@ fn default_llm_provider() -> String {
"ollama".to_string()
}
fn default_model() -> String {
"llama2".to_string()
}
fn default_max_tokens() -> u32 {
500
}
@@ -604,50 +443,6 @@ fn default_timeout() -> u64 {
30
}
fn default_openai_model() -> String {
"gpt-4".to_string()
}
fn default_openai_base_url() -> String {
"https://api.openai.com/v1".to_string()
}
fn default_ollama_url() -> String {
"http://localhost:11434".to_string()
}
fn default_ollama_model() -> String {
"llama2".to_string()
}
fn default_anthropic_model() -> String {
"claude-3-sonnet-20240229".to_string()
}
fn default_kimi_model() -> String {
"moonshot-v1-8k".to_string()
}
fn default_kimi_base_url() -> String {
"https://api.moonshot.cn/v1".to_string()
}
fn default_deepseek_model() -> String {
"deepseek-chat".to_string()
}
fn default_deepseek_base_url() -> String {
"https://api.deepseek.com/v1".to_string()
}
fn default_openrouter_model() -> String {
"openai/gpt-3.5-turbo".to_string()
}
fn default_openrouter_base_url() -> String {
"https://openrouter.ai/api/v1".to_string()
}
fn default_commit_format() -> CommitFormat {
CommitFormat::Conventional
}
@@ -696,39 +491,89 @@ impl AppConfig {
/// Save configuration to file
pub fn save(&self, path: &Path) -> Result<()> {
let content = toml::to_string_pretty(self)
.context("Failed to serialize config")?;
let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("Failed to create config directory: {:?}", parent))?;
}
fs::write(path, content)
.with_context(|| format!("Failed to write config file: {:?}", path))?;
Ok(())
}
/// Get default config path
pub fn default_path() -> Result<PathBuf> {
let config_dir = dirs::config_dir()
.context("Could not find config directory")?;
let config_dir = dirs::config_dir().context("Could not find config directory")?;
Ok(config_dir.join("quicommit").join("config.toml"))
}
/// Get profile for a repository
pub fn get_profile_for_repo(&self, repo_path: &str) -> Option<&GitProfile> {
let profile_name = self.repo_profiles.get(repo_path)?;
self.profiles.get(profile_name)
// /// Get profile for a repository
// pub fn get_profile_for_repo(&self, repo_path: &str) -> Option<&GitProfile> {
// let profile_name = self.repo_profiles.get(repo_path)?;
// self.profiles.get(profile_name)
// }
// /// Set profile for a repository
// pub fn set_profile_for_repo(&mut self, repo_path: String, profile_name: String) -> Result<()> {
// if !self.profiles.contains_key(&profile_name) {
// anyhow::bail!("Profile '{}' does not exist", profile_name);
// }
// self.repo_profiles.insert(repo_path, profile_name);
// Ok(())
// }
}
/// Encrypted PAT data for export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedPat {
/// Profile name
pub profile_name: String,
/// Service name (e.g., github, gitlab)
pub service: String,
/// User email (for keyring lookup)
pub user_email: String,
/// Encrypted token value
pub encrypted_token: String,
}
/// Export data container with optional encrypted PATs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportData {
/// Configuration content (TOML string)
pub config: String,
/// Encrypted PATs (only present when exporting with encryption)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub encrypted_pats: Vec<EncryptedPat>,
/// Export version for future compatibility
#[serde(default = "default_export_version")]
pub export_version: String,
}
fn default_export_version() -> String {
"1".to_string()
}
impl ExportData {
pub fn new(config: String) -> Self {
Self {
config,
encrypted_pats: Vec::new(),
export_version: default_export_version(),
}
}
/// Set profile for a repository
pub fn set_profile_for_repo(&mut self, repo_path: String, profile_name: String) -> Result<()> {
if !self.profiles.contains_key(&profile_name) {
anyhow::bail!("Profile '{}' does not exist", profile_name);
pub fn with_encrypted_pats(config: String, pats: Vec<EncryptedPat>) -> Self {
Self {
config,
encrypted_pats: pats,
export_version: default_export_version(),
}
self.repo_profiles.insert(repo_path, profile_name);
Ok(())
}
pub fn has_encrypted_pats(&self) -> bool {
!self.encrypted_pats.is_empty()
}
}

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
@@ -80,25 +80,25 @@ impl GitProfile {
if self.user_name.is_empty() {
bail!("User name cannot be empty");
}
if self.user_email.is_empty() {
bail!("User email cannot be empty");
}
crate::utils::validators::validate_email(&self.user_email)?;
if let Some(ref ssh) = self.ssh {
ssh.validate()?;
}
if let Some(ref gpg) = self.gpg {
gpg.validate()?;
}
for token in self.tokens.values() {
token.validate()?;
}
Ok(())
}
@@ -120,8 +120,7 @@ impl GitProfile {
/// Get signing key (from GPG config or direct)
pub fn signing_key(&self) -> Option<&str> {
self.signing_key
.as_ref()
.map(|s| s.as_str())
.as_deref()
.or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str()))
}
@@ -144,7 +143,7 @@ impl GitProfile {
pub fn record_usage(&mut self, repo_path: Option<String>) {
self.usage.last_used = Some(chrono::Utc::now().to_rfc3339());
self.usage.total_uses += 1;
if let Some(repo) = repo_path {
let count = self.usage.repo_usage.entry(repo).or_insert(0);
*count += 1;
@@ -159,93 +158,95 @@ impl GitProfile {
/// Apply this profile to a git repository (local config)
pub fn apply_to_repo(&self, repo: &git2::Repository) -> Result<()> {
let mut config = repo.config()?;
config.set_str("user.name", &self.user_name)?;
config.set_str("user.email", &self.user_email)?;
if let Some(key) = self.signing_key() {
config.set_str("user.signingkey", key)?;
if self.settings.auto_sign_commits {
config.set_bool("commit.gpgsign", true)?;
}
if self.settings.auto_sign_tags {
config.set_bool("tag.gpgsign", true)?;
}
}
if let Some(ref ssh) = self.ssh {
if let Some(ref key_path) = ssh.private_key_path {
let path_str = key_path.display().to_string();
#[cfg(target_os = "windows")]
{
config.set_str("core.sshCommand",
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")))?;
}
#[cfg(not(target_os = "windows"))]
{
config.set_str("core.sshCommand",
&format!("ssh -i '{}'", path_str))?;
}
if let Some(ref ssh) = self.ssh
&& let Some(ref key_path) = ssh.private_key_path
{
let path_str = key_path.display().to_string();
#[cfg(target_os = "windows")]
{
config.set_str(
"core.sshCommand",
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")),
)?;
}
#[cfg(not(target_os = "windows"))]
{
config.set_str("core.sshCommand", &format!("ssh -i '{}'", path_str))?;
}
}
Ok(())
}
/// Apply this profile globally
pub fn apply_global(&self) -> Result<()> {
let mut config = git2::Config::open_default()?;
config.set_str("user.name", &self.user_name)?;
config.set_str("user.email", &self.user_email)?;
if let Some(key) = self.signing_key() {
config.set_str("user.signingkey", key)?;
if self.settings.auto_sign_commits {
config.set_bool("commit.gpgsign", true)?;
}
if self.settings.auto_sign_tags {
config.set_bool("tag.gpgsign", true)?;
}
}
if let Some(ref ssh) = self.ssh {
if let Some(ref key_path) = ssh.private_key_path {
let path_str = key_path.display().to_string();
#[cfg(target_os = "windows")]
{
config.set_str("core.sshCommand",
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")))?;
}
#[cfg(not(target_os = "windows"))]
{
config.set_str("core.sshCommand",
&format!("ssh -i '{}'", path_str))?;
}
if let Some(ref ssh) = self.ssh
&& let Some(ref key_path) = ssh.private_key_path
{
let path_str = key_path.display().to_string();
#[cfg(target_os = "windows")]
{
config.set_str(
"core.sshCommand",
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")),
)?;
}
#[cfg(not(target_os = "windows"))]
{
config.set_str("core.sshCommand", &format!("ssh -i '{}'", path_str))?;
}
}
Ok(())
}
/// Compare with current git configuration
pub fn compare_with_git_config(&self, repo: &git2::Repository) -> Result<ProfileComparison> {
let config = repo.config()?;
let git_user_name = config.get_string("user.name").ok();
let git_user_email = config.get_string("user.email").ok();
let git_signing_key = config.get_string("user.signingkey").ok();
let mut comparison = ProfileComparison {
profile_name: self.name.clone(),
matches: true,
differences: vec![],
};
if git_user_name.as_deref() != Some(&self.user_name) {
comparison.matches = false;
comparison.differences.push(ConfigDifference {
@@ -254,7 +255,7 @@ impl GitProfile {
git_value: git_user_name.unwrap_or_else(|| "<not set>".to_string()),
});
}
if git_user_email.as_deref() != Some(&self.user_email) {
comparison.matches = false;
comparison.differences.push(ConfigDifference {
@@ -263,24 +264,24 @@ impl GitProfile {
git_value: git_user_email.unwrap_or_else(|| "<not set>".to_string()),
});
}
if let Some(profile_key) = self.signing_key() {
if git_signing_key.as_deref() != Some(profile_key) {
comparison.matches = false;
comparison.differences.push(ConfigDifference {
key: "user.signingkey".to_string(),
profile_value: profile_key.to_string(),
git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()),
});
}
if let Some(profile_key) = self.signing_key()
&& git_signing_key.as_deref() != Some(profile_key)
{
comparison.matches = false;
comparison.differences.push(ConfigDifference {
key: "user.signingkey".to_string(),
profile_value: profile_key.to_string(),
git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()),
});
}
Ok(comparison)
}
}
/// Profile settings
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProfileSettings {
/// Automatically sign commits
#[serde(default)]
@@ -307,19 +308,6 @@ pub struct ProfileSettings {
pub commit_template: Option<String>,
}
impl Default for ProfileSettings {
fn default() -> Self {
Self {
auto_sign_commits: false,
auto_sign_tags: false,
default_commit_format: None,
repo_patterns: vec![],
llm_provider: None,
commit_template: None,
}
}
}
/// SSH configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshConfig {
@@ -349,18 +337,18 @@ pub struct SshConfig {
impl SshConfig {
/// Validate SSH configuration
pub fn validate(&self) -> Result<()> {
if let Some(ref path) = self.private_key_path {
if !path.exists() {
bail!("SSH private key does not exist: {:?}", path);
}
if let Some(ref path) = self.private_key_path
&& !path.exists()
{
bail!("SSH private key does not exist: {:?}", path);
}
if let Some(ref path) = self.public_key_path {
if !path.exists() {
bail!("SSH public key does not exist: {:?}", path);
}
if let Some(ref path) = self.public_key_path
&& !path.exists()
{
bail!("SSH public key does not exist: {:?}", path);
}
Ok(())
}
@@ -423,10 +411,6 @@ impl GpgConfig {
/// Token configuration for services (GitHub, GitLab, etc.)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenConfig {
/// Token value (encrypted)
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
/// Token type (personal, oauth, etc.)
#[serde(default)]
pub token_type: TokenType,
@@ -446,25 +430,41 @@ pub struct TokenConfig {
/// Description
#[serde(default)]
pub description: Option<String>,
/// Indicates if a token is stored in keyring
#[serde(default)]
pub has_token: bool,
}
impl TokenConfig {
/// Create a new token config
pub fn new(token: String, token_type: TokenType) -> Self {
/// Create a new token config (token stored separately in keyring)
pub fn new(token_type: TokenType) -> Self {
Self {
token: Some(token),
token_type,
scopes: vec![],
expires_at: None,
last_used: None,
description: None,
has_token: true,
}
}
/// Create a new token config without token
pub fn without_token(token_type: TokenType) -> Self {
Self {
token_type,
scopes: vec![],
expires_at: None,
last_used: None,
description: None,
has_token: false,
}
}
/// Validate token configuration
pub fn validate(&self) -> Result<()> {
if self.token.is_none() && self.token_type != TokenType::None {
bail!("Token value is required for {:?}", self.token_type);
if !self.has_token && self.token_type != TokenType::None {
bail!("Token is required for {:?}", self.token_type);
}
Ok(())
}
@@ -473,12 +473,19 @@ impl TokenConfig {
pub fn record_usage(&mut self) {
self.last_used = Some(chrono::Utc::now().to_rfc3339());
}
/// Mark that a token is stored
pub fn set_has_token(&mut self, has_token: bool) {
self.has_token = has_token;
}
}
/// Token type
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum TokenType {
#[default]
None,
Personal,
OAuth,
@@ -486,12 +493,6 @@ pub enum TokenType {
App,
}
impl Default for TokenType {
fn default() -> Self {
Self::None
}
}
impl std::fmt::Display for TokenType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -621,9 +622,15 @@ impl GitProfileBuilder {
}
pub fn build(self) -> Result<GitProfile> {
let name = self.name.ok_or_else(|| anyhow::anyhow!("Name is required"))?;
let user_name = self.user_name.ok_or_else(|| anyhow::anyhow!("User name is required"))?;
let user_email = self.user_email.ok_or_else(|| anyhow::anyhow!("User email is required"))?;
let name = self
.name
.ok_or_else(|| anyhow::anyhow!("Name is required"))?;
let user_name = self
.user_name
.ok_or_else(|| anyhow::anyhow!("User name is required"))?;
let user_email = self
.user_email
.ok_or_else(|| anyhow::anyhow!("User email is required"))?;
Ok(GitProfile {
name,
@@ -669,13 +676,13 @@ mod tests {
"".to_string(),
"invalid-email".to_string(),
);
assert!(profile.validate().is_err());
}
#[test]
fn test_token_config() {
let token = TokenConfig::new("test-token".to_string(), TokenType::Personal);
let token = TokenConfig::new(TokenType::Personal);
assert!(token.validate().is_ok());
}
}

View File

@@ -1,4 +1,5 @@
use crate::config::{CommitFormat, LlmConfig, Language};
use crate::config::manager::ConfigManager;
use crate::config::{CommitFormat, Language};
use crate::git::{CommitInfo, GitRepo};
use crate::llm::{GeneratedCommit, LlmClient};
use anyhow::{Context, Result};
@@ -10,17 +11,44 @@ pub struct ContentGenerator {
impl ContentGenerator {
/// Create new content generator
pub async fn new(config: &LlmConfig) -> Result<Self> {
let llm_client = LlmClient::from_config(config).await?;
// Check if provider is available
if !llm_client.is_available().await {
anyhow::bail!("LLM provider '{}' is not available", config.provider);
pub async fn new(manager: &ConfigManager) -> Result<Self> {
Self::new_with_think(manager, false).await
}
/// Create new content generator with thinking override
pub async fn new_with_think(manager: &ConfigManager, think_override: bool) -> Result<Self> {
let mut thinking_enabled = if think_override {
true
} else {
manager.config().llm.thinking_enabled
};
// Validate thinking support per provider
if thinking_enabled {
let provider = manager.llm_provider();
if !Self::supports_thinking(provider) {
eprintln!(
"Warning: Provider '{}' does not support thinking mode. \
Disabling thinking for this invocation.",
provider
);
thinking_enabled = false;
}
}
let llm_client = LlmClient::from_config_with_think(manager, thinking_enabled).await?;
if !llm_client.is_available().await {
anyhow::bail!("LLM provider '{}' is not available", manager.llm_provider());
}
Ok(Self { llm_client })
}
fn supports_thinking(provider: &str) -> bool {
matches!(provider, "deepseek" | "kimi" | "anthropic" | "openai")
}
/// Generate commit message from diff
pub async fn generate_commit_message(
&self,
@@ -31,12 +59,15 @@ impl ContentGenerator {
// Truncate diff if too long
let max_diff_len = 4000;
let truncated_diff = if diff.len() > max_diff_len {
format!("{}\n... (truncated)", &diff[..max_diff_len])
let boundary = diff.floor_char_boundary(max_diff_len);
format!("{}\n... (truncated)", &diff[..boundary])
} else {
diff.to_string()
};
self.llm_client.generate_commit_message(&truncated_diff, format, language).await
self.llm_client
.generate_commit_message(&truncated_diff, format, language)
.await
}
/// Generate commit message from repository changes
@@ -46,13 +77,14 @@ impl ContentGenerator {
format: CommitFormat,
language: Language,
) -> Result<GeneratedCommit> {
let diff = repo.get_staged_diff()
let diff = repo
.get_staged_diff_sorted()
.context("Failed to get staged diff")?;
if diff.is_empty() {
anyhow::bail!("No staged changes to generate commit from");
}
self.generate_commit_message(&diff, format, language).await
}
@@ -63,12 +95,12 @@ impl ContentGenerator {
commits: &[CommitInfo],
language: Language,
) -> Result<String> {
let commit_messages: Vec<String> = commits
.iter()
.map(|c| c.subject().to_string())
.collect();
self.llm_client.generate_tag_message(version, &commit_messages, language).await
let commit_messages: Vec<String> =
commits.iter().map(|c| c.subject().to_string()).collect();
self.llm_client
.generate_tag_message(version, &commit_messages, language)
.await
}
/// Generate changelog entry
@@ -85,8 +117,10 @@ impl ContentGenerator {
(commit_type, c.subject().to_string())
})
.collect();
self.llm_client.generate_changelog_entry(version, &typed_commits, language).await
self.llm_client
.generate_changelog_entry(version, &typed_commits, language)
.await
}
/// Generate changelog from repository
@@ -102,8 +136,9 @@ impl ContentGenerator {
} else {
repo.get_commits(50)?
};
self.generate_changelog_entry(version, &commits, language).await
self.generate_changelog_entry(version, &commits, language)
.await
}
/// Interactive commit generation with user feedback
@@ -114,49 +149,53 @@ impl ContentGenerator {
language: Language,
) -> Result<GeneratedCommit> {
use dialoguer::Select;
let diff = repo.get_staged_diff()?;
let diff = repo.get_staged_diff_sorted()?;
if diff.is_empty() {
anyhow::bail!("No staged changes");
}
// Show diff summary
let files = repo.get_staged_files()?;
println!("\nStaged files ({}):", files.len());
for file in &files {
println!("{}", file);
}
// Generate initial commit
println!("\nGenerating commit message...");
let mut generated = self.generate_commit_message(&diff, format, language).await?;
let mut generated = self
.generate_commit_message(&diff, format, language)
.await?;
loop {
println!("\n{}", "".repeat(60));
println!("Generated commit message:");
println!("{}", "".repeat(60));
println!("{}", generated.to_conventional());
println!("{}", "".repeat(60));
let options = vec![
"✓ Accept and commit",
"🔄 Regenerate",
"✏️ Edit",
"❌ Cancel",
];
let selection = Select::new()
.with_prompt("What would you like to do?")
.items(&options)
.default(0)
.interact()?;
match selection {
0 => return Ok(generated),
1 => {
println!("Regenerating...");
generated = self.generate_commit_message(&diff, format, language).await?;
generated = self
.generate_commit_message(&diff, format, language)
.await?;
}
2 => {
let edited = crate::utils::editor::edit_content(&generated.to_conventional())?;
@@ -170,7 +209,7 @@ impl ContentGenerator {
fn parse_edited_commit(&self, edited: &str, _format: CommitFormat) -> Result<GeneratedCommit> {
let parsed = crate::git::commit::parse_commit_message(edited);
Ok(GeneratedCommit {
commit_type: parsed.commit_type.unwrap_or_else(|| "chore".to_string()),
scope: parsed.scope,
@@ -207,11 +246,15 @@ pub mod fallback {
let has_code = files.iter().any(|f| {
f.ends_with(".rs") || f.ends_with(".py") || f.ends_with(".js") || f.ends_with(".ts")
});
let has_docs = files.iter().any(|f| f.ends_with(".md") || f.contains("README"));
let has_tests = files.iter().any(|f| f.contains("test") || f.contains("spec"));
let has_docs = files
.iter()
.any(|f| f.ends_with(".md") || f.contains("README"));
let has_tests = files
.iter()
.any(|f| f.contains("test") || f.contains("spec"));
if has_tests {
"test: update tests".to_string()
} else if has_docs {

View File

@@ -95,9 +95,7 @@ impl ChangelogGenerator {
ChangelogFormat::GitHubReleases => {
self.generate_github_releases(version, date, commits)
}
ChangelogFormat::Custom => {
self.generate_custom(version, date, commits)
}
ChangelogFormat::Custom => self.generate_custom(version, date, commits),
}
}
@@ -110,13 +108,13 @@ impl ChangelogGenerator {
commits: &[CommitInfo],
) -> Result<()> {
let entry = self.generate(version, date, commits)?;
let existing = if changelog_path.exists() {
fs::read_to_string(changelog_path)?
} else {
String::new()
};
let new_content = if existing.is_empty() {
format!("{}{}", CHANGELOG_HEADER, entry)
} else if existing.starts_with(CHANGELOG_HEADER) {
@@ -124,7 +122,7 @@ impl ChangelogGenerator {
} else if existing.starts_with("# Changelog") {
let lines: Vec<&str> = existing.lines().collect();
let mut header_end = 0;
for (i, line) in lines.iter().enumerate() {
if i == 0 && line.starts_with('#') {
header_end = i + 1;
@@ -134,18 +132,18 @@ impl ChangelogGenerator {
break;
}
}
let header = lines[..header_end].join("\n");
let rest = lines[header_end..].join("\n");
format!("{}\n{}\n{}", header, entry, rest)
} else {
format!("{}{}", CHANGELOG_HEADER, entry)
};
fs::write(changelog_path, new_content)
.with_context(|| format!("Failed to write changelog: {:?}", changelog_path))?;
Ok(())
}
@@ -157,10 +155,10 @@ impl ChangelogGenerator {
) -> Result<String> {
let date_str = date.format("%Y-%m-%d").to_string();
let mut output = format!("## [{}] - {}\n\n", version, date_str);
if self.group_by_type {
let grouped = self.group_commits(commits);
let _grouped = self.group_commits(commits);
// Standard categories
let categories = vec![
("Added", vec!["feat"]),
@@ -170,7 +168,7 @@ impl ChangelogGenerator {
("Fixed", vec!["fix"]),
("Security", vec!["security"]),
];
for (title, types) in &categories {
let items: Vec<&CommitInfo> = commits
.iter()
@@ -182,7 +180,7 @@ impl ChangelogGenerator {
}
})
.collect();
if !items.is_empty() {
output.push_str(&format!("### {}\n\n", title));
for commit in items {
@@ -192,13 +190,13 @@ impl ChangelogGenerator {
output.push('\n');
}
}
// Other changes
let categorized: Vec<String> = categories
.iter()
.flat_map(|(_, types)| types.iter().map(|s| s.to_string()))
.collect();
let other: Vec<&CommitInfo> = commits
.iter()
.filter(|c| {
@@ -209,7 +207,7 @@ impl ChangelogGenerator {
}
})
.collect();
if !other.is_empty() {
output.push_str("### Other\n\n");
for commit in other {
@@ -224,30 +222,30 @@ impl ChangelogGenerator {
output.push('\n');
}
}
Ok(output)
}
fn generate_github_releases(
&self,
version: &str,
_version: &str,
_date: DateTime<Utc>,
commits: &[CommitInfo],
) -> Result<String> {
let mut output = format!("## What's Changed\n\n");
let mut output = "## What's Changed\n\n".to_string();
// Group by type
let mut features = vec![];
let mut fixes = vec![];
let mut docs = vec![];
let mut other = vec![];
let mut breaking = vec![];
for commit in commits {
if commit.message.contains("BREAKING CHANGE") {
breaking.push(commit);
}
if let Some(ref t) = commit.commit_type() {
match t.as_str() {
"feat" => features.push(commit),
@@ -259,7 +257,7 @@ impl ChangelogGenerator {
other.push(commit);
}
}
if !breaking.is_empty() {
output.push_str("### ⚠ Breaking Changes\n\n");
for commit in breaking {
@@ -267,7 +265,7 @@ impl ChangelogGenerator {
}
output.push('\n');
}
if !features.is_empty() {
output.push_str("### 🚀 Features\n\n");
for commit in features {
@@ -275,7 +273,7 @@ impl ChangelogGenerator {
}
output.push('\n');
}
if !fixes.is_empty() {
output.push_str("### 🐛 Bug Fixes\n\n");
for commit in fixes {
@@ -283,7 +281,7 @@ impl ChangelogGenerator {
}
output.push('\n');
}
if !docs.is_empty() {
output.push_str("### 📚 Documentation\n\n");
for commit in docs {
@@ -291,14 +289,14 @@ impl ChangelogGenerator {
}
output.push('\n');
}
if !other.is_empty() {
output.push_str("### Other Changes\n\n");
for commit in other {
output.push_str(&self.format_commit_github(commit));
}
}
Ok(output)
}
@@ -312,7 +310,7 @@ impl ChangelogGenerator {
if !self.custom_categories.is_empty() {
let date_str = date.format("%Y-%m-%d").to_string();
let mut output = format!("## [{}] - {}\n\n", version, date_str);
for category in &self.custom_categories {
let items: Vec<&CommitInfo> = commits
.iter()
@@ -324,7 +322,7 @@ impl ChangelogGenerator {
}
})
.collect();
if !items.is_empty() {
output.push_str(&format!("### {}\n\n", category.title));
for commit in items {
@@ -334,7 +332,7 @@ impl ChangelogGenerator {
output.push('\n');
}
}
Ok(output)
} else {
// Fall back to keep-a-changelog
@@ -344,30 +342,35 @@ impl ChangelogGenerator {
fn format_commit(&self, commit: &CommitInfo) -> String {
let mut line = format!("- {}", commit.subject());
if self.include_hashes {
line.push_str(&format!(" ({})", &commit.short_id));
}
if self.include_authors {
line.push_str(&format!(" - @{}", commit.author));
}
line
}
fn format_commit_github(&self, commit: &CommitInfo) -> String {
format!("- {} by @{} in {}\n", commit.subject(), commit.author, &commit.short_id)
format!(
"- {} by @{} in {}\n",
commit.subject(),
commit.author,
&commit.short_id
)
}
fn group_commits<'a>(&self, commits: &'a [CommitInfo]) -> HashMap<String, Vec<&'a CommitInfo>> {
let mut groups: HashMap<String, Vec<&'a CommitInfo>> = HashMap::new();
for commit in commits {
let commit_type = commit.commit_type().unwrap_or_else(|| "other".to_string());
groups.entry(commit_type).or_default().push(commit);
}
groups
}
}
@@ -380,8 +383,7 @@ impl Default for ChangelogGenerator {
/// Read existing changelog
pub fn read_changelog(path: &Path) -> Result<String> {
fs::read_to_string(path)
.with_context(|| format!("Failed to read changelog: {:?}", path))
fs::read_to_string(path).with_context(|| format!("Failed to read changelog: {:?}", path))
}
/// Initialize new changelog file
@@ -389,10 +391,10 @@ pub fn init_changelog(path: &Path) -> Result<()> {
if path.exists() {
anyhow::bail!("Changelog already exists at {:?}", path);
}
fs::write(path, CHANGELOG_HEADER)
.with_context(|| format!("Failed to create changelog: {:?}", path))?;
Ok(())
}
@@ -403,7 +405,7 @@ pub fn generate_from_history(
to_ref: Option<&str>,
) -> Result<Vec<CommitInfo>> {
let to_ref = to_ref.unwrap_or("HEAD");
if let Some(from) = from_tag {
repo.get_commits_between(from, to_ref)
} else {
@@ -413,11 +415,7 @@ pub fn generate_from_history(
}
/// Update version links in changelog
pub fn update_version_links(
changelog: &str,
version: &str,
compare_url: &str,
) -> String {
pub fn update_version_links(changelog: &str, version: &str, compare_url: &str) -> String {
// Add version link at the end of changelog
format!("{}\n[{}]: {}\n", changelog, version, compare_url)
}
@@ -425,30 +423,29 @@ pub fn update_version_links(
/// Parse changelog to extract versions
pub fn parse_versions(changelog: &str) -> Vec<(String, String)> {
let mut versions = vec![];
for line in changelog.lines() {
if line.starts_with("## [") {
if let Some(start) = line.find('[') {
if let Some(end) = line.find(']') {
let version = &line[start + 1..end];
if version != "Unreleased" {
if let Some(date_start) = line.find(" - ") {
let date = &line[date_start + 3..].trim();
versions.push((version.to_string(), date.to_string()));
}
}
}
if line.starts_with("## [")
&& let Some(start) = line.find('[')
&& let Some(end) = line.find(']')
{
let version = &line[start + 1..end];
if version != "Unreleased"
&& let Some(date_start) = line.find(" - ")
{
let date = &line[date_start + 3..].trim();
versions.push((version.to_string(), date.to_string()));
}
}
}
versions
}
/// Get unreleased changes
pub fn get_unreleased_changes(repo: &GitRepo) -> Result<Vec<CommitInfo>> {
let tags = repo.get_tags()?;
if let Some(latest_tag) = tags.first() {
repo.get_commits_between(&latest_tag.name, "HEAD")
} else {

View File

@@ -1,5 +1,5 @@
use super::GitRepo;
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use chrono::Local;
/// Commit builder for creating commits
@@ -119,10 +119,14 @@ impl CommitBuilder {
return Ok(msg.clone());
}
let commit_type = self.commit_type.as_ref()
let commit_type = self
.commit_type
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Commit type is required"))?;
let description = self.description.as_ref()
let description = self
.description
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Description is required"))?;
let message = match self.format {
@@ -166,45 +170,46 @@ impl CommitBuilder {
fn amend_commit(&self, repo: &GitRepo, message: &str) -> Result<()> {
use std::process::Command;
let mut args = vec!["commit", "--amend"];
if self.no_verify {
args.push("--no-verify");
}
args.push("-m");
args.push(message);
if self.sign {
args.push("-S");
}
let output = Command::new("git")
.args(&args)
.current_dir(repo.path())
.output()?;
if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() {
if stdout.is_empty() {
"GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\
3. You can sign commits manually (try: git commit --amend -S)".to_string()
3. You can sign commits manually (try: git commit --amend -S)"
.to_string()
} else {
stdout.to_string()
}
} else {
stderr.to_string()
};
bail!("Failed to amend commit: {}", error_msg);
}
Ok(())
}
}
@@ -219,7 +224,7 @@ impl Default for CommitBuilder {
pub fn create_date_commit_message(prefix: Option<&str>) -> String {
let now = Local::now();
let date_str = now.format("%Y-%m-%d").to_string();
match prefix {
Some(p) => format!("{}: {}", p, date_str),
None => format!("chore: update {}", date_str),
@@ -229,58 +234,65 @@ pub fn create_date_commit_message(prefix: Option<&str>) -> String {
/// Commit type suggestions based on diff
pub fn suggest_commit_type(diff: &str) -> Vec<&'static str> {
let mut suggestions = vec![];
// Check for test files
if diff.contains("test") || diff.contains("spec") || diff.contains("__tests__") {
suggestions.push("test");
}
// Check for documentation
if diff.contains("README") || diff.contains(".md") || diff.contains("docs/") {
suggestions.push("docs");
}
// Check for configuration files
if diff.contains("config") || diff.contains(".json") || diff.contains(".yaml") || diff.contains(".toml") {
if diff.contains("config")
|| diff.contains(".json")
|| diff.contains(".yaml")
|| diff.contains(".toml")
{
suggestions.push("chore");
}
// Check for dependencies
if diff.contains("Cargo.toml") || diff.contains("package.json") || diff.contains("requirements.txt") {
if diff.contains("Cargo.toml")
|| diff.contains("package.json")
|| diff.contains("requirements.txt")
{
suggestions.push("build");
}
// Check for CI
if diff.contains(".github/") || diff.contains(".gitlab-") || diff.contains("Jenkinsfile") {
suggestions.push("ci");
}
// Default suggestions
if suggestions.is_empty() {
suggestions.extend(&["feat", "fix", "refactor"]);
}
suggestions
}
/// Parse existing commit message
pub fn parse_commit_message(message: &str) -> ParsedCommit {
let lines: Vec<&str> = message.lines().collect();
if lines.is_empty() {
return ParsedCommit::default();
}
let first_line = lines[0];
// Try to parse as conventional commit
if let Some(colon_pos) = first_line.find(':') {
let type_part = &first_line[..colon_pos];
let description = first_line[colon_pos + 1..].trim();
let breaking = type_part.ends_with('!');
let type_part = type_part.trim_end_matches('!');
let (commit_type, scope) = if let Some(open) = type_part.find('(') {
if let Some(close) = type_part.find(')') {
let t = &type_part[..open];
@@ -292,42 +304,51 @@ pub fn parse_commit_message(message: &str) -> ParsedCommit {
} else {
(Some(type_part.to_string()), None)
};
// Extract body and footer
let mut body_lines = vec![];
let mut footer_lines = vec![];
let mut in_footer = false;
for line in &lines[1..] {
if line.trim().is_empty() {
continue;
}
if line.starts_with("BREAKING CHANGE:") ||
line.starts_with("Closes") ||
line.starts_with("Fixes") ||
line.starts_with("Refs") ||
line.starts_with("Co-authored-by:") {
if line.starts_with("BREAKING CHANGE:")
|| line.starts_with("Closes")
|| line.starts_with("Fixes")
|| line.starts_with("Refs")
|| line.starts_with("Co-authored-by:")
{
in_footer = true;
}
if in_footer {
footer_lines.push(line.to_string());
} else {
body_lines.push(line.to_string());
}
}
return ParsedCommit {
commit_type,
scope,
description: Some(description.to_string()),
body: if body_lines.is_empty() { None } else { Some(body_lines.join("\n")) },
footer: if footer_lines.is_empty() { None } else { Some(footer_lines.join("\n")) },
body: if body_lines.is_empty() {
None
} else {
Some(body_lines.join("\n"))
},
footer: if footer_lines.is_empty() {
None
} else {
Some(footer_lines.join("\n"))
},
breaking,
};
}
// Non-conventional commit
ParsedCommit {
description: Some(first_line.to_string()),
@@ -351,7 +372,7 @@ impl ParsedCommit {
pub fn to_message(&self, format: crate::config::CommitFormat) -> String {
let commit_type = self.commit_type.as_deref().unwrap_or("chore");
let description = self.description.as_deref().unwrap_or("update");
match format {
crate::config::CommitFormat::Conventional => {
crate::utils::formatter::format_conventional_commit(

View File

@@ -1,54 +1,49 @@
use anyhow::{bail, Context, Result};
use git2::{Repository, Signature, StatusOptions, Config, Oid, ObjectType};
use std::path::{Path, PathBuf, Component};
use anyhow::{Context, Result, bail};
use git2::{Config, ObjectType, Oid, Repository, Signature, StatusOptions};
use std::collections::HashMap;
use tempfile;
use std::path::{Component, Path, PathBuf};
pub mod changelog;
pub mod commit;
pub mod tag;
#[cfg(target_os = "windows")]
use std::os::windows::ffi::OsStringExt;
fn normalize_path_for_git2(path: &Path) -> PathBuf {
let mut normalized = path.to_path_buf();
#[cfg(target_os = "windows")]
{
let path_str = path.to_string_lossy();
if path_str.starts_with(r"\\?\") {
if let Some(stripped) = path_str.strip_prefix(r"\\?\") {
normalized = PathBuf::from(stripped);
}
if path_str.starts_with(r"\\?\")
&& let Some(stripped) = path_str.strip_prefix(r"\\?\")
{
normalized = PathBuf::from(stripped);
}
if path_str.starts_with(r"\\?\UNC\") {
if let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") {
normalized = PathBuf::from(format!(r"\\{}", stripped));
}
if path_str.starts_with(r"\\?\UNC\")
&& let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\")
{
normalized = PathBuf::from(format!(r"\\{}", stripped));
}
}
normalized
}
fn get_absolute_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
let path = path.as_ref();
if path.is_absolute() {
return Ok(normalize_path_for_git2(path));
}
let current_dir = std::env::current_dir()
.with_context(|| "Failed to get current directory")?;
let current_dir = std::env::current_dir().with_context(|| "Failed to get current directory")?;
let absolute = current_dir.join(path);
Ok(normalize_path_for_git2(&absolute))
}
fn resolve_path_without_canonicalize(path: &Path) -> PathBuf {
let mut components = Vec::new();
for component in path.components() {
match component {
Component::ParentDir => {
@@ -62,62 +57,62 @@ fn resolve_path_without_canonicalize(path: &Path) -> PathBuf {
_ => components.push(component),
}
}
let mut result = PathBuf::new();
for component in components {
result.push(component.as_os_str());
}
normalize_path_for_git2(&result)
}
fn try_open_repo_with_git2(path: &Path) -> Result<Repository> {
let normalized = normalize_path_for_git2(path);
let discover_opts = git2::RepositoryOpenFlags::empty();
let ceiling_dirs: [&str; 0] = [];
let repo = Repository::open_ext(&normalized, discover_opts, &ceiling_dirs)
let repo = Repository::open_ext(&normalized, discover_opts, ceiling_dirs)
.or_else(|_| Repository::discover(&normalized))
.or_else(|_| Repository::open(&normalized));
repo.map_err(|e| anyhow::anyhow!("git2 failed: {}", e))
}
fn try_open_repo_with_git_cli(path: &Path) -> Result<Repository> {
let output = std::process::Command::new("git")
.args(&["rev-parse", "--show-toplevel"])
.args(["rev-parse", "--show-toplevel"])
.current_dir(path)
.output()
.context("Failed to execute git command")?;
if !output.status.success() {
bail!("git CLI failed to find repository");
}
let stdout = String::from_utf8_lossy(&output.stdout);
let git_root = stdout.trim();
if git_root.is_empty() {
bail!("git CLI returned empty path");
}
let git_root_path = PathBuf::from(git_root);
let normalized = normalize_path_for_git2(&git_root_path);
Repository::open(&normalized)
.with_context(|| format!("Failed to open repo from git CLI path: {:?}", normalized))
}
fn diagnose_repo_issue(path: &Path) -> String {
let mut issues = Vec::new();
if !path.exists() {
issues.push(format!("Path does not exist: {:?}", path));
} else if !path.is_dir() {
issues.push(format!("Path is not a directory: {:?}", path));
}
let git_dir = path.join(".git");
if git_dir.exists() {
if git_dir.is_dir() {
@@ -133,7 +128,7 @@ fn diagnose_repo_issue(path: &Path) -> String {
}
} else {
issues.push("No .git found in current directory".to_string());
let mut current = path;
let mut depth = 0;
while let Some(parent) = current.parent() {
@@ -149,7 +144,7 @@ fn diagnose_repo_issue(path: &Path) -> String {
current = parent;
}
}
#[cfg(target_os = "windows")]
{
let path_str = path.to_string_lossy();
@@ -160,11 +155,11 @@ fn diagnose_repo_issue(path: &Path) -> String {
issues.push("WARNING: Path has mixed path separators".to_string());
}
}
if let Ok(current_dir) = std::env::current_dir() {
issues.push(format!("Current working directory: {:?}", current_dir));
}
issues.join("\n ")
}
@@ -177,17 +172,15 @@ pub struct GitRepo {
impl GitRepo {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
let absolute_path = get_absolute_path(path)?;
let resolved_path = resolve_path_without_canonicalize(&absolute_path);
let repo = try_open_repo_with_git2(&resolved_path)
.or_else(|git2_err| {
try_open_repo_with_git_cli(&resolved_path)
.map_err(|cli_err| {
let diagnosis = diagnose_repo_issue(&resolved_path);
anyhow::anyhow!(
"Failed to open git repository:\n\
let repo = try_open_repo_with_git2(&resolved_path).or_else(|git2_err| {
try_open_repo_with_git_cli(&resolved_path).map_err(|cli_err| {
let diagnosis = diagnose_repo_issue(&resolved_path);
anyhow::anyhow!(
"Failed to open git repository:\n\
\n\
=== git2 Error ===\n {}\n\
\n\
@@ -200,17 +193,20 @@ impl GitRepo {
2. Run: git status (to verify git works)\n\
3. Run: git config --global --add safe.directory \"*\"\n\
4. Check file permissions",
git2_err, cli_err, diagnosis
)
})
})?;
let repo_path = repo.workdir()
git2_err,
cli_err,
diagnosis
)
})
})?;
let repo_path = repo
.workdir()
.map(|p| p.to_path_buf())
.unwrap_or_else(|| resolved_path.clone());
let config = repo.config().ok();
Ok(Self {
repo,
path: normalize_path_for_git2(&repo_path),
@@ -251,7 +247,11 @@ impl GitRepo {
pub fn get_user_name(&self) -> Result<String> {
self.get_config("user.name")?
.or_else(|| std::env::var("GIT_AUTHOR_NAME").ok())
.ok_or_else(|| anyhow::anyhow!("User name not configured. Set it with: git config user.name \"Your Name\""))
.ok_or_else(|| {
anyhow::anyhow!(
"User name not configured. Set it with: git config user.name \"Your Name\""
)
})
}
/// Get the configured user email
@@ -263,7 +263,8 @@ impl GitRepo {
/// Get the configured GPG signing key
pub fn get_signing_key(&self) -> Result<Option<String>> {
Ok(self.get_config("user.signingkey")?
Ok(self
.get_config("user.signingkey")?
.or_else(|| std::env::var("GIT_SIGNING_KEY").ok()))
}
@@ -290,13 +291,9 @@ impl GitRepo {
if let Some(program) = self.get_config("gpg.program")? {
return Ok(program);
}
let default_gpg = if cfg!(windows) {
"gpg.exe"
} else {
"gpg"
};
let default_gpg = if cfg!(windows) { "gpg.exe" } else { "gpg" };
Ok(default_gpg.to_string())
}
@@ -304,10 +301,13 @@ impl GitRepo {
pub fn create_signature(&self) -> Result<Signature<'_>> {
let name = self.get_user_name()?;
let email = self.get_user_email()?;
let time = git2::Time::new(std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64, 0);
let time = git2::Time::new(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
0,
);
Signature::new(&name, &email, &time).map_err(Into::into)
}
@@ -332,7 +332,7 @@ impl GitRepo {
pub fn get_staged_diff(&self) -> Result<String> {
// Use git CLI to get staged diff for better compatibility
let output = std::process::Command::new("git")
.args(&["diff", "--cached"])
.args(["diff", "--cached"])
.current_dir(&self.path)
.output()
.with_context(|| "Failed to get staged diff with git command")?;
@@ -346,6 +346,128 @@ impl GitRepo {
Ok(diff_text)
}
/// Get staged diff with files sorted by importance
/// Important files (source code) come first, then config files like Cargo.toml,
/// then lock files like Cargo.lock
pub fn get_staged_diff_sorted(&self) -> Result<String> {
let diff = self.get_staged_diff()?;
if diff.is_empty() {
return Ok(diff);
}
let mut file_diffs = Vec::new();
let mut current_file_diff = String::new();
let mut current_file = String::new();
for line in diff.lines() {
if line.starts_with("diff --git") {
// Save previous file diff if any
if !current_file_diff.is_empty() && !current_file.is_empty() {
file_diffs.push((current_file.clone(), current_file_diff.clone()));
}
current_file = extract_file_from_diff_line(line);
current_file_diff = format!("{}\n", line);
} else {
current_file_diff.push_str(line);
current_file_diff.push('\n');
}
}
// Add the last file diff
if !current_file_diff.is_empty() && !current_file.is_empty() {
file_diffs.push((current_file, current_file_diff));
}
// Sort by file importance
file_diffs.sort_by(|a, b| {
let score_a = file_importance_score(&a.0);
let score_b = file_importance_score(&b.0);
score_b.cmp(&score_a) // Descending order
});
// Combine sorted diffs
let sorted_diff: String = file_diffs.into_iter().map(|(_, diff)| diff).collect();
Ok(sorted_diff)
}
}
/// Extract filename from diff --git line
fn extract_file_from_diff_line(line: &str) -> String {
// Format: "diff --git a/path/to/file b/path/to/file"
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 {
// Return the second path (after b/)
if let Some(path) = parts[2].strip_prefix("b/") {
return path.to_string();
}
// Fallback to first path (after a/)
if let Some(path) = parts[1].strip_prefix("a/") {
return path.to_string();
}
}
line.to_string()
}
/// Calculate file importance score
/// Higher score = more important
fn file_importance_score(filename: &str) -> i32 {
// Priority list for important file types
let important_extensions = [
".rs", ".py", ".js", ".ts", ".tsx", ".jsx", ".go", ".java", ".cpp", ".c", ".rust", ".vue",
".svelte", ".html", ".css", ".scss", ".sass", ".less",
];
// Config files that are important but less than source code
let config_files = [
"Cargo.toml",
"package.json",
"go.mod",
"go.sum",
"pom.xml",
"Makefile",
"CMakeLists.txt",
"build.gradle",
"gradle.properties",
];
// Lock files - lowest priority
let lock_files = [
"Cargo.lock",
"package-lock.json",
"yarn.lock",
"pnpm-lock.yaml",
"Gemfile.lock",
"composer.lock",
];
// Check lock files first (lowest priority)
for lock in lock_files.iter() {
if filename.ends_with(lock) {
return 1;
}
}
// Check config files (medium priority)
for config in config_files.iter() {
if filename.ends_with(config) {
return 2;
}
}
// Check important source files (highest priority)
for ext in important_extensions.iter() {
if filename.ends_with(ext) {
return 3;
}
}
// Default priority for other files
2
}
impl GitRepo {
/// Get unstaged diff
pub fn get_unstaged_diff(&self) -> Result<String> {
let diff = self.repo.diff_index_to_workdir(None, None)?;
@@ -390,18 +512,21 @@ impl GitRepo {
/// Get list of staged files
pub fn get_staged_files(&self) -> Result<Vec<String>> {
let statuses = self.repo.statuses(Some(
StatusOptions::new()
.include_untracked(false),
))?;
let statuses = self
.repo
.statuses(Some(StatusOptions::new().include_untracked(false)))?;
let mut files = vec![];
for entry in statuses.iter() {
let status = entry.status();
if status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange() {
if let Some(path) = entry.path() {
files.push(path.to_string());
}
if (status.is_index_new()
|| status.is_index_modified()
|| status.is_index_deleted()
|| status.is_index_renamed()
|| status.is_index_typechange())
&& let Some(path) = entry.path()
{
files.push(path.to_string());
}
}
@@ -431,7 +556,7 @@ impl GitRepo {
pub fn stage_all(&self) -> Result<()> {
// Use git command for reliable staging (handles all edge cases)
let output = std::process::Command::new("git")
.args(&["add", "-A"])
.args(["add", "-A"])
.current_dir(&self.path)
.output()
.with_context(|| "Failed to stage changes with git command")?;
@@ -515,27 +640,28 @@ impl GitRepo {
std::fs::write(temp_file.path(), message)?;
let output = std::process::Command::new("git")
.args(&["commit", "-S", "-F", temp_file.path().to_str().unwrap()])
.args(["commit", "-S", "-F", temp_file.path().to_str().unwrap()])
.current_dir(&self.path)
.output()?;
if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() {
if stdout.is_empty() {
"GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\
3. You can sign commits manually (try: git commit -S -m 'test')".to_string()
3. You can sign commits manually (try: git commit -S -m 'test')"
.to_string()
} else {
stdout.to_string()
}
} else {
stderr.to_string()
};
bail!("Failed to create signed commit: {}", error_msg);
}
@@ -548,7 +674,8 @@ impl GitRepo {
let head = self.repo.head()?;
if head.is_branch() {
let name = head.shorthand()
let name = head
.shorthand()
.ok_or_else(|| anyhow::anyhow!("Invalid branch name"))?;
Ok(name.to_string())
} else {
@@ -559,7 +686,8 @@ impl GitRepo {
/// Get current commit hash (short)
pub fn current_commit_short(&self) -> Result<String> {
let head = self.repo.head()?;
let oid = head.target()
let oid = head
.target()
.ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?;
Ok(oid.to_string()[..8].to_string())
}
@@ -567,7 +695,8 @@ impl GitRepo {
/// Get current commit hash (full)
pub fn current_commit(&self) -> Result<String> {
let head = self.repo.head()?;
let oid = head.target()
let oid = head
.target()
.ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?;
Ok(oid.to_string())
}
@@ -642,12 +771,16 @@ impl GitRepo {
name: name.to_string(),
target: oid.to_string(),
message: commit.message().unwrap_or("").to_string(),
time: commit.time().seconds(),
});
}
true
})?;
// Sort tags by time (newest first)
tags.sort_by(|a, b| b.time.cmp(&a.time));
Ok(tags)
}
@@ -662,13 +795,7 @@ impl GitRepo {
if sign {
self.create_signed_tag_with_git2(name, msg, &sig, target.id())?;
} else {
self.repo.tag(
name,
target.as_object(),
&sig,
msg,
false,
)?;
self.repo.tag(name, target.as_object(), &sig, msg, false)?;
}
} else {
self.repo.tag(
@@ -684,9 +811,15 @@ impl GitRepo {
}
/// Create signed tag using git CLI
fn create_signed_tag_with_git2(&self, name: &str, message: &str, _signature: &Signature, _target_id: Oid) -> Result<()> {
fn create_signed_tag_with_git2(
&self,
name: &str,
message: &str,
_signature: &Signature,
_target_id: Oid,
) -> Result<()> {
let output = std::process::Command::new("git")
.args(&["tag", "-s", name, "-m", message])
.args(["tag", "-s", name, "-m", message])
.current_dir(&self.path)
.output()?;
@@ -699,7 +832,12 @@ impl GitRepo {
}
/// Create GPG signature for arbitrary content
fn create_gpg_signature_for_content(&self, _content: &str, _gpg_program: &str, _signing_key: &str) -> Result<String> {
fn create_gpg_signature_for_content(
&self,
_content: &str,
_gpg_program: &str,
_signing_key: &str,
) -> Result<String> {
Ok(String::new())
}
@@ -712,7 +850,7 @@ impl GitRepo {
/// Push to remote
pub fn push(&self, remote: &str, refspec: &str) -> Result<()> {
let output = std::process::Command::new("git")
.args(&["push", remote, refspec])
.args(["push", remote, refspec])
.current_dir(&self.path)
.output()?;
@@ -727,7 +865,8 @@ impl GitRepo {
/// Get remote URL
pub fn get_remote_url(&self, remote: &str) -> Result<String> {
let remote_obj = self.repo.find_remote(remote)?;
let url = remote_obj.url()
let url = remote_obj
.url()
.ok_or_else(|| anyhow::anyhow!("Remote has no URL"))?;
Ok(url.to_string())
}
@@ -741,7 +880,7 @@ impl GitRepo {
pub fn status_summary(&self) -> Result<StatusSummary> {
// Use git CLI for more reliable status detection
let output = std::process::Command::new("git")
.args(&["status", "--porcelain"])
.args(["status", "--porcelain"])
.current_dir(&self.path)
.output()
.with_context(|| "Failed to get status with git command")?;
@@ -778,9 +917,10 @@ impl GitRepo {
}
// Conflicted files (both columns are U or DD, AA, etc.)
if (index_status == 'U' || worktree_status == 'U') ||
(index_status == 'A' && worktree_status == 'A') ||
(index_status == 'D' && worktree_status == 'D') {
if (index_status == 'U' || worktree_status == 'U')
|| (index_status == 'A' && worktree_status == 'A')
|| (index_status == 'D' && worktree_status == 'D')
{
conflicted += 1;
}
}
@@ -832,6 +972,7 @@ pub struct TagInfo {
pub name: String,
pub target: String,
pub message: String,
pub time: i64,
}
/// Repository status summary
@@ -870,52 +1011,51 @@ impl StatusSummary {
pub fn find_repo<P: AsRef<Path>>(start_path: P) -> Result<GitRepo> {
let start_path = start_path.as_ref();
let absolute_start = get_absolute_path(start_path)?;
let resolved_start = resolve_path_without_canonicalize(&absolute_start);
if let Ok(repo) = GitRepo::open(&resolved_start) {
return Ok(repo);
}
let mut current = resolved_start.as_path();
let mut attempted_paths = vec![current.to_string_lossy().to_string()];
let max_depth = 50;
let mut depth = 0;
while let Some(parent) = current.parent() {
depth += 1;
if depth > max_depth {
break;
}
attempted_paths.push(parent.to_string_lossy().to_string());
if let Ok(repo) = GitRepo::open(parent) {
return Ok(repo);
}
current = parent;
}
if let Ok(output) = std::process::Command::new("git")
.args(&["rev-parse", "--show-toplevel"])
.args(["rev-parse", "--show-toplevel"])
.current_dir(&resolved_start)
.output()
&& output.status.success()
{
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
let git_root = stdout.trim();
if !git_root.is_empty() {
if let Ok(repo) = GitRepo::open(git_root) {
return Ok(repo);
}
}
let stdout = String::from_utf8_lossy(&output.stdout);
let git_root = stdout.trim();
if !git_root.is_empty()
&& let Ok(repo) = GitRepo::open(git_root)
{
return Ok(repo);
}
}
let diagnosis = diagnose_repo_issue(&resolved_start);
bail!(
"No git repository found.\n\
\n\
@@ -1037,6 +1177,105 @@ impl<'a> GitConfigHelper<'a> {
}
}
/// Configuration source indicator
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigSource {
Local,
Global,
NotSet,
}
impl std::fmt::Display for ConfigSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigSource::Local => write!(f, "local"),
ConfigSource::Global => write!(f, "global"),
ConfigSource::NotSet => write!(f, "not set"),
}
}
}
/// Single configuration entry with source information
#[derive(Debug, Clone)]
pub struct ConfigEntry {
pub value: Option<String>,
pub source: ConfigSource,
pub local_value: Option<String>,
pub global_value: Option<String>,
}
impl ConfigEntry {
pub fn new(local: Option<String>, global: Option<String>) -> Self {
let (value, source) = match (&local, &global) {
(Some(_), _) => (local.clone(), ConfigSource::Local),
(None, Some(_)) => (global.clone(), ConfigSource::Global),
(None, None) => (None, ConfigSource::NotSet),
};
Self {
value,
source,
local_value: local,
global_value: global,
}
}
pub fn is_set(&self) -> bool {
self.value.is_some()
}
pub fn is_local(&self) -> bool {
self.source == ConfigSource::Local
}
pub fn is_global(&self) -> bool {
self.source == ConfigSource::Global
}
}
/// Merged user configuration with local/global source tracking
#[derive(Debug, Clone)]
pub struct MergedUserConfig {
pub name: ConfigEntry,
pub email: ConfigEntry,
pub signing_key: ConfigEntry,
pub ssh_command: ConfigEntry,
pub commit_gpgsign: ConfigEntry,
pub tag_gpgsign: ConfigEntry,
}
impl MergedUserConfig {
pub fn from_repo(repo: &Repository) -> Result<Self> {
let local_config = repo.config().ok();
let global_config = git2::Config::open_default().ok();
let get_entry = |key: &str| -> ConfigEntry {
let local = local_config.as_ref().and_then(|c| c.get_string(key).ok());
let global = global_config.as_ref().and_then(|c| c.get_string(key).ok());
ConfigEntry::new(local, global)
};
Ok(Self {
name: get_entry("user.name"),
email: get_entry("user.email"),
signing_key: get_entry("user.signingkey"),
ssh_command: get_entry("core.sshCommand"),
commit_gpgsign: get_entry("commit.gpgsign"),
tag_gpgsign: get_entry("tag.gpgsign"),
})
}
pub fn is_complete(&self) -> bool {
self.name.is_set() && self.email.is_set()
}
pub fn has_local_overrides(&self) -> bool {
self.name.is_local()
|| self.email.is_local()
|| self.signing_key.is_local()
|| self.ssh_command.is_local()
}
}
/// User configuration for git
#[derive(Debug, Clone)]
pub struct UserConfig {
@@ -1060,23 +1299,38 @@ impl UserConfig {
diffs.push(ConfigDiff {
key: "user.name".to_string(),
left: self.name.clone().unwrap_or_else(|| "<not set>".to_string()),
right: other.name.clone().unwrap_or_else(|| "<not set>".to_string()),
right: other
.name
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
});
}
if self.email != other.email {
diffs.push(ConfigDiff {
key: "user.email".to_string(),
left: self.email.clone().unwrap_or_else(|| "<not set>".to_string()),
right: other.email.clone().unwrap_or_else(|| "<not set>".to_string()),
left: self
.email
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
right: other
.email
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
});
}
if self.signing_key != other.signing_key {
diffs.push(ConfigDiff {
key: "user.signingkey".to_string(),
left: self.signing_key.clone().unwrap_or_else(|| "<not set>".to_string()),
right: other.signing_key.clone().unwrap_or_else(|| "<not set>".to_string()),
left: self
.signing_key
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
right: other
.signing_key
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
});
}

View File

@@ -1,5 +1,5 @@
use super::GitRepo;
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use semver::Version;
/// Tag builder for creating tags
@@ -69,19 +69,19 @@ impl TagBuilder {
/// Build tag message
pub fn build_message(&self) -> Result<String> {
let message = self.message.as_ref()
.cloned()
.unwrap_or_else(|| {
let name = self.name.as_deref().unwrap_or("unknown");
format!("Release {}", name)
});
let message = self.message.as_ref().cloned().unwrap_or_else(|| {
let name = self.name.as_deref().unwrap_or("unknown");
format!("Release {}", name)
});
Ok(message)
}
/// Execute tag creation
pub fn execute(&self, repo: &GitRepo) -> Result<()> {
let name = self.name.as_ref()
let name = self
.name
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Tag name is required"))?;
if !self.force {
@@ -105,10 +105,10 @@ impl TagBuilder {
/// Execute and push tag
pub fn execute_and_push(&self, repo: &GitRepo, remote: &str) -> Result<()> {
self.execute(repo)?;
let name = self.name.as_ref().unwrap();
repo.push(remote, &format!("refs/tags/{}", name))?;
Ok(())
}
}
@@ -136,7 +136,10 @@ impl VersionBump {
"minor" => Ok(Self::Minor),
"patch" => Ok(Self::Patch),
"prerelease" | "pre" => Ok(Self::Prerelease),
_ => bail!("Invalid version bump: {}. Use: major, minor, patch, prerelease", s),
_ => bail!(
"Invalid version bump: {}. Use: major, minor, patch, prerelease",
s
),
}
}
@@ -149,7 +152,7 @@ impl VersionBump {
/// Get latest version tag from repository
pub fn get_latest_version(repo: &GitRepo, prefix: &str) -> Result<Option<Version>> {
let tags = repo.get_tags()?;
let mut versions: Vec<Version> = tags
.iter()
.filter_map(|t| {
@@ -158,9 +161,9 @@ pub fn get_latest_version(repo: &GitRepo, prefix: &str) -> Result<Option<Version
Version::parse(version_str).ok()
})
.collect();
versions.sort_by(|a, b| b.cmp(a)); // Descending order
Ok(versions.into_iter().next())
}
@@ -183,14 +186,17 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
let mut has_breaking = false;
let mut has_feature = false;
let mut has_fix = false;
for commit in commits {
let msg = commit.message.to_lowercase();
if msg.contains("breaking change") || msg.contains("breaking-change") || msg.contains("breaking_change") {
if msg.contains("breaking change")
|| msg.contains("breaking-change")
|| msg.contains("breaking_change")
{
has_breaking = true;
}
if let Some(commit_type) = commit.commit_type() {
match commit_type.as_str() {
"feat" => has_feature = true,
@@ -199,7 +205,7 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
}
}
}
if has_breaking {
VersionBump::Major
} else if has_feature {
@@ -214,20 +220,20 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
/// Generate tag message from commits
pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> String {
let mut message = format!("Release {}\n\n", version);
// Group commits by type
let mut features = vec![];
let mut fixes = vec![];
let mut other = vec![];
let mut breaking = vec![];
for commit in commits {
let subject = commit.subject();
if commit.message.contains("BREAKING CHANGE") {
breaking.push(subject.to_string());
}
if let Some(commit_type) = commit.commit_type() {
match commit_type.as_str() {
"feat" => features.push(subject.to_string()),
@@ -238,7 +244,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
other.push(subject.to_string());
}
}
// Build message
if !breaking.is_empty() {
message.push_str("## Breaking Changes\n\n");
@@ -247,7 +253,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
}
message.push('\n');
}
if !features.is_empty() {
message.push_str("## Features\n\n");
for item in &features {
@@ -255,7 +261,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
}
message.push('\n');
}
if !fixes.is_empty() {
message.push_str("## Bug Fixes\n\n");
for item in &fixes {
@@ -263,36 +269,36 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
}
message.push('\n');
}
if !other.is_empty() {
message.push_str("## Other Changes\n\n");
for item in &other {
message.push_str(&format!("- {}\n", item));
}
}
message
}
/// Tag deletion helper
pub fn delete_tag(repo: &GitRepo, name: &str, remote: Option<&str>) -> Result<()> {
repo.delete_tag(name)?;
if let Some(remote) = remote {
use std::process::Command;
let refspec = format!(":refs/tags/{}", name);
let output = Command::new("git")
.args(&["push", remote, &refspec])
.args(["push", remote, &refspec])
.current_dir(repo.path())
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
bail!("Failed to delete remote tag: {}", stderr);
}
}
Ok(())
}
@@ -303,7 +309,7 @@ pub fn list_tags(
limit: Option<usize>,
) -> Result<Vec<super::TagInfo>> {
let tags = repo.get_tags()?;
let filtered: Vec<_> = tags
.into_iter()
.filter(|t| {
@@ -314,7 +320,7 @@ pub fn list_tags(
}
})
.collect();
if let Some(limit) = limit {
Ok(filtered.into_iter().take(limit).collect())
} else {

View File

@@ -267,7 +267,9 @@ impl Messages {
Language::Chinese => "没有可提交的更改。工作树是干净的。",
Language::Japanese => "コミットする変更がありません。作業ツリーはクリーンです。",
Language::Korean => "커밋할 변경 사항이 없습니다. 작업 트리가 깨끗합니다.",
Language::Spanish => "No hay cambios para hacer commit. El árbol de trabajo está limpio.",
Language::Spanish => {
"No hay cambios para hacer commit. El árbol de trabajo está limpio."
}
Language::French => "Aucun changement à commiter. L'arbre de travail est propre.",
Language::German => "Keine Änderungen zum Committen. Arbeitsbaum ist sauber.",
}
@@ -289,11 +291,19 @@ impl Messages {
match self.language {
Language::English => "No files staged. Auto-staging all changes...",
Language::Chinese => "没有暂存文件。自动暂存所有更改...",
Language::Japanese => "ステージされたファイルがありません。すべての変更を自動ステージ中...",
Language::Japanese => {
"ステージされたファイルがありません。すべての変更を自動ステージ中..."
}
Language::Korean => "스테이징된 파일이 없습니다. 모든 변경 사항을 자동 스테이징 중...",
Language::Spanish => "No hay archivos preparados. Preparando automáticamente todos los cambios...",
Language::French => "Aucun fichier indexé. Indexation automatique de tous les changements...",
Language::German => "Keine Dateien bereitgestellt. Alle Änderungen werden automatisch bereitgestellt...",
Language::Spanish => {
"No hay archivos preparados. Preparando automáticamente todos los cambios..."
}
Language::French => {
"Aucun fichier indexé. Indexation automatique de tous les changements..."
}
Language::German => {
"Keine Dateien bereitgestellt. Alle Änderungen werden automatisch bereitgestellt..."
}
}
}
@@ -359,12 +369,23 @@ impl Messages {
pub fn ai_generating_tag(&self, count: usize) -> String {
match self.language {
Language::English => format!("🤖 AI is generating tag message from {} commits...", count),
Language::English => {
format!("🤖 AI is generating tag message from {} commits...", count)
}
Language::Chinese => format!("🤖 AI 正在从 {} 个提交生成标签消息...", count),
Language::Japanese => format!("🤖 AIが{}個のコミットからタグメッセージを生成しています...", count),
Language::Japanese => format!(
"🤖 AIが{}個のコミットからタグメッセージを生成しています...",
count
),
Language::Korean => format!("🤖 AI가 {}개의 커밋에서 태그 메시지를 생성 중...", count),
Language::Spanish => format!("🤖 La IA está generando mensaje de etiqueta desde {} commits...", count),
Language::French => format!("🤖 L'IA génère le message dtiquette à partir de {} commits...", count),
Language::Spanish => format!(
"🤖 La IA está generando mensaje de etiqueta desde {} commits...",
count
),
Language::French => format!(
"🤖 L'IA génère le message d'étiquette à partir de {} commits...",
count
),
Language::German => format!("🤖 KI generiert Tag-Nachricht aus {} Commits...", count),
}
}

View File

@@ -7,7 +7,11 @@ pub struct Translator {
}
impl Translator {
pub fn new(language: Language, keep_types_english: bool, keep_changelog_types_english: bool) -> Self {
pub fn new(
language: Language,
keep_types_english: bool,
keep_changelog_types_english: bool,
) -> Self {
Self {
language,
keep_types_english,
@@ -227,7 +231,11 @@ pub fn translate_commit_type(commit_type: &str, language: Language, keep_english
translator.translate_commit_type(commit_type)
}
pub fn translate_changelog_category(category: &str, language: Language, keep_english: bool) -> String {
pub fn translate_changelog_category(
category: &str,
language: Language,
keep_english: bool,
) -> String {
let translator = Translator::new(language, true, keep_english);
translator.translate_changelog_category(category)
}

View File

@@ -1,7 +1,9 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// Anthropic Claude API client
@@ -9,6 +11,12 @@ pub struct AnthropicClient {
api_key: String,
model: String,
client: reqwest::Client,
thinking_enabled: bool,
thinking_budget_tokens: u32,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
}
#[derive(Debug, Serialize)]
@@ -17,24 +25,59 @@ struct MessagesRequest {
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
system: Option<Vec<SystemContent>>,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Clone)]
struct SystemContent {
#[serde(rename = "type")]
content_type: String,
text: String,
}
#[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
thinking_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
budget_tokens: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AnthropicMessage {
role: String,
content: String,
content: AnthropicContent,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
enum AnthropicContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ContentBlock {
#[serde(rename = "type")]
content_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessagesResponse {
content: Vec<ContentBlock>,
content: Vec<ResponseContentBlock>,
}
#[derive(Debug, Deserialize)]
struct ContentBlock {
struct ResponseContentBlock {
#[serde(rename = "type")]
content_type: String,
text: String,
@@ -52,31 +95,112 @@ struct AnthropicError {
message: String,
}
// --- Streaming SSE event structures ---
#[derive(Debug, Deserialize)]
struct SseEvent {
#[serde(rename = "type")]
event_type: String,
#[serde(default)]
message: Option<SseMessage>,
#[serde(default)]
index: Option<u32>,
#[serde(default)]
content_block: Option<SseContentBlock>,
#[serde(default)]
delta: Option<SseDelta>,
#[serde(default)]
usage: Option<SseUsage>,
}
#[derive(Debug, Deserialize)]
struct SseMessage {
#[serde(default)]
content: Option<Vec<SseContentBlock>>,
}
#[derive(Debug, Deserialize)]
struct SseContentBlock {
#[serde(rename = "type")]
content_type: String,
#[serde(default)]
thinking: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct SseDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
#[serde(default)]
thinking: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct SseUsage {
#[serde(default)]
output_tokens: Option<u32>,
}
impl AnthropicClient {
/// Create new Anthropic client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
thinking_budget_tokens: 1024,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_thinking_budget_tokens(mut self, budget_tokens: u32) -> Self {
self.thinking_budget_tokens = budget_tokens;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> {
// Anthropic doesn't have a models API endpoint, return predefined list
Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
let url = "https://api.anthropic.com/v1/messages";
@@ -84,14 +208,18 @@ impl AnthropicClient {
model: self.model.clone(),
max_tokens: 5,
temperature: Some(0.0),
top_p: None,
messages: vec![AnthropicMessage {
role: "user".to_string(),
content: "Hi".to_string(),
content: AnthropicContent::Text("Hi".to_string()),
}],
system: None,
thinking: None,
stream: false,
};
let response = self.client
let response = self
.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
@@ -124,25 +252,28 @@ impl LlmProvider for AnthropicClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![AnthropicMessage {
role: "user".to_string(),
content: prompt.to_string(),
content: AnthropicContent::Text(prompt.to_string()),
}];
self.messages_request(messages, None).await
self.messages_request_with_retry(messages, None).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let messages = vec![AnthropicMessage {
role: "user".to_string(),
content: user.to_string(),
content: AnthropicContent::Text(user.to_string()),
}];
let system = if system.is_empty() {
None
} else {
Some(system.to_string())
Some(vec![SystemContent {
content_type: "text".to_string(),
text: system.to_string(),
}])
};
self.messages_request(messages, system).await
self.messages_request_with_retry(messages, system).await
}
async fn is_available(&self) -> bool {
@@ -155,22 +286,84 @@ impl LlmProvider for AnthropicClient {
}
impl AnthropicClient {
async fn messages_request_with_retry(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self
.messages_request(messages.clone(), system.clone())
.await
{
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<String>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
if self.thinking_enabled {
self.streaming_messages_request(messages, system).await
} else {
self.non_streaming_messages_request(messages, system).await
}
}
async fn non_streaming_messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
let url = "https://api.anthropic.com/v1/messages";
let temperature = if self.temperature == 0.0 {
None
} else {
Some(self.temperature)
};
let request = MessagesRequest {
model: self.model.clone(),
max_tokens: 500,
temperature: Some(0.7),
max_tokens: self.max_tokens,
temperature,
top_p: self.top_p,
messages,
system,
thinking: Some(ThinkingConfig {
thinking_type: "disabled".to_string(),
budget_tokens: None,
}),
stream: false,
};
let response = self.client
let response = self
.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
@@ -179,35 +372,205 @@ impl AnthropicClient {
.send()
.await
.context("Failed to send request to Anthropic")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type);
bail!(
"Anthropic API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Anthropic API error: {} - {}", status, text);
}
let result: MessagesResponse = response
.json()
.await
.context("Failed to parse Anthropic response")?;
result.content
result
.content
.into_iter()
.find(|c| c.content_type == "text")
.map(|c| c.text.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
}
/// Streaming request for thinking mode, filters thinking content blocks
async fn streaming_messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
let url = "https://api.anthropic.com/v1/messages";
let thinking = ThinkingConfig {
thinking_type: "enabled".to_string(),
budget_tokens: Some(self.thinking_budget_tokens),
};
// max_tokens must exceed budget_tokens
let max_tokens = (self.max_tokens).max(self.thinking_budget_tokens + 100);
let request = MessagesRequest {
model: self.model.clone(),
max_tokens,
temperature: None, // must be omitted for thinking mode
top_p: None,
messages,
system,
thinking: Some(thinking),
stream: true,
};
let response = self
.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(&request)
.send()
.await
.context("Failed to send streaming request to Anthropic")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Anthropic API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Anthropic API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut in_thinking = false;
let mut has_reasoning = false;
let mut has_content = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
// Parse SSE event line
if let Some(data) = line.strip_prefix("data: ") {
if let Ok(event) = serde_json::from_str::<SseEvent>(data) {
match event.event_type.as_str() {
"content_block_start" => {
if let Some(ref block) = event.content_block {
if block.content_type == "thinking" {
in_thinking = true;
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
}
}
}
"content_block_delta" => {
if let Some(ref delta) = event.delta {
// Thinking delta - ignore content but track state
if delta.thinking.is_some() {
continue;
}
// Text delta - collect
if in_thinking && delta.text.is_some() {
// Transition from thinking to text
if let Some(state) = thinking_state {
state.end_thinking();
}
in_thinking = false;
}
if let Some(ref text) = delta.text
&& !text.is_empty()
{
has_content = true;
content_buffer.push_str(text);
}
}
}
"content_block_stop" => {
if in_thinking {
if let Some(state) = thinking_state {
state.end_thinking();
}
in_thinking = false;
}
}
_ => {}
}
}
}
}
}
// Ensure thinking state is ended
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"Anthropic returned thinking content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from Anthropic. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// Available Anthropic models
/// Available Anthropic models (Claude 4 series with extended thinking)
pub const ANTHROPIC_MODELS: &[&str] = &[
"claude-opus-4-7",
"claude-sonnet-4-6",
"claude-haiku-4-5",
// Legacy models
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
@@ -216,7 +579,6 @@ pub const ANTHROPIC_MODELS: &[&str] = &[
"claude-instant-1.2",
];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool {
ANTHROPIC_MODELS.contains(&model)
}
@@ -226,8 +588,68 @@ mod tests {
use super::*;
#[test]
fn test_model_validation() {
fn test_model_validation_claude4() {
assert!(is_valid_model("claude-opus-4-7"));
assert!(is_valid_model("claude-sonnet-4-6"));
assert!(is_valid_model("claude-haiku-4-5"));
assert!(is_valid_model("claude-3-sonnet-20240229"));
assert!(!is_valid_model("invalid-model"));
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
budget_tokens: Some(2048),
};
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains(r#""type":"enabled""#));
assert!(json.contains(r#""budget_tokens":2048"#));
}
#[test]
fn test_thinking_config_disabled_serialization() {
let config = ThinkingConfig {
thinking_type: "disabled".to_string(),
budget_tokens: None,
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"disabled"}"#);
}
#[test]
fn test_system_content_serialization() {
let content = SystemContent {
content_type: "text".to_string(),
text: "You are helpful.".to_string(),
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"text""#));
}
#[test]
fn test_sse_event_parsing_content_block_start() {
let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#;
let event: SseEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.event_type, "content_block_start");
assert_eq!(event.content_block.unwrap().content_type, "thinking");
}
#[test]
fn test_sse_event_parsing_text_delta() {
let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
let event: SseEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.event_type, "content_block_delta");
assert_eq!(event.delta.unwrap().text, Some("Hello".to_string()));
}
#[test]
fn test_anthropic_content_text() {
let msg = AnthropicMessage {
role: "user".to_string(),
content: AnthropicContent::Text("Hello".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""content":"Hello""#));
}
}

View File

@@ -1,7 +1,9 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// DeepSeek API client
@@ -10,6 +12,11 @@ pub struct DeepSeekClient {
api_key: String,
model: String,
client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
thinking_state: Option<Arc<ThinkingStateManager>>,
}
#[derive(Debug, Serialize)]
@@ -20,13 +27,31 @@ struct ChatCompletionRequest {
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
thinking_type: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -37,6 +62,31 @@ struct ChatCompletionResponse {
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
#[serde(default)]
reasoning_content: Option<String>,
}
// --- Streaming response structures ---
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
#[serde(default)]
finish_reason: Option<String>,
index: Option<u32>,
}
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -52,41 +102,73 @@ struct ApiError {
}
impl DeepSeekClient {
/// Create new DeepSeek client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: "https://api.deepseek.com/v1".to_string(),
base_url: "https://api.deepseek.com".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
thinking_state: None,
})
}
/// Create with custom base URL
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
thinking_state: None,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
self.reasoning_effort = effort;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self.client
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
@@ -101,11 +183,11 @@ impl DeepSeekClient {
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
data: Vec<ModelId>,
}
#[derive(Deserialize)]
struct Model {
struct ModelId {
id: String,
}
@@ -117,7 +199,6 @@ impl DeepSeekClient {
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
@@ -136,32 +217,33 @@ impl DeepSeekClient {
#[async_trait]
impl LlmProvider for DeepSeekClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
reasoning_content: None,
}];
self.chat_completion_with_retry(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
reasoning_content: None,
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
reasoning_content: None,
});
self.chat_completion(messages).await
self.chat_completion_with_retry(messages).await
}
async fn is_available(&self) -> bool {
@@ -174,59 +256,291 @@ impl LlmProvider for DeepSeekClient {
}
impl DeepSeekClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
// 网络临时错误才重试
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
// 指数退避
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let thinking = Some(ThinkingConfig {
thinking_type: if self.thinking_enabled {
"enabled".to_string()
} else {
"disabled".to_string()
},
});
// 思考模式下temperature/top_p 等参数不应传递
// 非思考模式下可以正常传递
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
if self.thinking_enabled {
(None, Some(self.max_tokens), None, None, None)
} else {
(
Some(self.temperature),
Some(self.max_tokens),
None,
None,
None,
)
};
let reasoning_effort = if self.thinking_enabled {
self.reasoning_effort.clone()
} else {
None
};
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
stream: false,
messages: messages.clone(),
max_tokens,
temperature,
top_p,
presence_penalty,
frequency_penalty,
stream: self.thinking_enabled,
thinking,
reasoning_effort,
};
let response = self.client
.post(&url)
if self.thinking_enabled {
self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.json(request)
.send()
.await
.context("Failed to send request to DeepSeek")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type);
bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("DeepSeek API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse DeepSeek response")?;
result.choices
result
.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
}
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to DeepSeek")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("DeepSeek API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
// 处理完整行
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
// SSE 格式data: {...} 或 data: [DONE]
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
// 处理 reasoning_content
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
// reasoning_content 不对外输出,仅用于内部状态判断
continue;
}
// 处理 content
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
// reasoning 结束content 开始出现时移除 thinking 标识
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
// 检查 finish_reason
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行(可能是心跳或注释)
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"DeepSeek returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from DeepSeek. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// Available DeepSeek models
/// 可用 DeepSeek 模型列表
/// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列
pub const DEEPSEEK_MODELS: &[&str] = &[
"deepseek-v4-flash",
"deepseek-v4-pro",
// 兼容旧版模型 ID将于 2026-07-24 停用)
"deepseek-chat",
"deepseek-coder",
"deepseek-reasoner",
];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool {
DEEPSEEK_MODELS.contains(&model)
}
@@ -236,8 +550,73 @@ mod tests {
use super::*;
#[test]
fn test_model_validation() {
fn test_model_validation_v4() {
assert!(is_valid_model("deepseek-v4-flash"));
assert!(is_valid_model("deepseek-v4-pro"));
assert!(is_valid_model("deepseek-chat"));
assert!(is_valid_model("deepseek-reasoner"));
assert!(!is_valid_model("invalid-model"));
assert!(!is_valid_model("deepseek-v3"));
}
}
#[test]
fn test_client_builder_defaults() {
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();
assert!(!client.thinking_enabled);
assert_eq!(client.max_tokens, 500);
assert_eq!(client.temperature, 0.7);
assert!(client.reasoning_effort.is_none());
assert!(client.thinking_state.is_none());
}
#[test]
fn test_client_builder_with_thinking() {
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash")
.unwrap()
.with_thinking(true)
.with_reasoning_effort(Some("high".to_string()))
.with_max_tokens(1000)
.with_temperature(0.5);
assert!(client.thinking_enabled);
assert_eq!(client.reasoning_effort, Some("high".to_string()));
assert_eq!(client.max_tokens, 1000);
assert_eq!(client.temperature, 0.5);
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"enabled"}"#);
}
#[test]
fn test_message_serialization_without_reasoning() {
let msg = Message {
role: "user".to_string(),
content: "Hello".to_string(),
reasoning_content: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(!json.contains("reasoning_content"));
}
#[test]
fn test_stream_delta_parsing() {
let json = r#"{"content":"Hello","reasoning_content":null}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert_eq!(delta.content, Some("Hello".to_string()));
assert!(delta.reasoning_content.is_none());
}
#[test]
fn test_stream_delta_reasoning_only() {
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert!(delta.content.is_none());
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
}
}

View File

@@ -1,244 +1,587 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Kimi API client (Moonshot AI)
pub struct KimiClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
}
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
#[serde(rename = "type")]
error_type: String,
}
impl KimiClient {
/// Create new Kimi client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: "https://api.moonshot.cn/v1".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Create with custom base URL
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.context("Failed to list Kimi models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("Kimi API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
}
#[derive(Deserialize)]
struct Model {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for KimiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"kimi"
}
}
impl KimiClient {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
stream: false,
};
let response = self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("Kimi API error: {} ({})", error.error.message, error.error.error_type);
}
bail!("Kimi API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
result.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
}
}
/// Available Kimi models
pub const KIMI_MODELS: &[&str] = &[
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool {
KIMI_MODELS.contains(&model)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation() {
assert!(is_valid_model("moonshot-v1-8k"));
assert!(!is_valid_model("invalid-model"));
}
}
use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// Kimi API client (Moonshot AI)
pub struct KimiClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
thinking_enabled: bool,
max_tokens: u32,
temperature: f32,
thinking_state: Option<Arc<ThinkingStateManager>>,
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
}
#[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
thinking_type: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
#[serde(default)]
reasoning_content: Option<String>,
}
// --- Streaming response structures ---
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
#[serde(default)]
finish_reason: Option<String>,
index: Option<u32>,
}
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
}
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
#[serde(rename = "type")]
error_type: String,
}
impl KimiClient {
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: "https://api.moonshot.cn/v1".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
max_tokens: 500,
temperature: 1.0,
thinking_state: None,
})
}
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
max_tokens: 500,
temperature: 1.0,
thinking_state: None,
})
}
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.context("Failed to list Kimi models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("Kimi API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<ModelId>,
}
#[derive(Deserialize)]
struct ModelId {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for KimiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
reasoning_content: None,
}];
self.chat_completion_with_retry(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
reasoning_content: None,
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
reasoning_content: None,
});
self.chat_completion_with_retry(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"kimi"
}
}
impl KimiClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let thinking = Some(ThinkingConfig {
thinking_type: if self.thinking_enabled {
"enabled".to_string()
} else {
"disabled".to_string()
},
});
// Kimi API temperature 要求:
// - 思考模式: temperature 必须为 1.0
// - 非思考模式: temperature 必须为 0.6
let temperature = if self.thinking_enabled {
Some(1.0)
} else {
Some(0.6)
};
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: messages.clone(),
max_tokens: Some(self.max_tokens),
temperature,
stream: self.thinking_enabled,
thinking,
};
if self.thinking_enabled {
self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(request)
.send()
.await
.context("Failed to send request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
result
.choices
.into_iter()
.next()
.map(|c| {
let content = c.message.content.trim().to_string();
if content.is_empty() {
c.reasoning_content
.or(c.message.reasoning_content)
.map(|r| r.trim().to_string())
.unwrap_or_default()
} else {
content
}
})
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
}
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
continue;
}
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"Kimi returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from Kimi. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// 可用 Kimi 模型列表
pub const KIMI_MODELS: &[&str] = &[
// K2 系列(推荐)
"kimi-k2.6",
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-thinking-turbo",
"kimi-k2-instruct",
"kimi-k2-instruct-0905",
// 兼容旧版模型 ID
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
];
pub fn is_valid_model(model: &str) -> bool {
KIMI_MODELS.contains(&model)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation_k2() {
assert!(is_valid_model("kimi-k2.6"));
assert!(is_valid_model("kimi-k2.5"));
assert!(is_valid_model("kimi-k2-thinking"));
assert!(is_valid_model("kimi-k2-thinking-turbo"));
assert!(is_valid_model("moonshot-v1-8k"));
assert!(is_valid_model("moonshot-v1-32k"));
assert!(is_valid_model("moonshot-v1-128k"));
assert!(!is_valid_model("invalid-model"));
assert!(!is_valid_model("kimi-k1.5"));
}
#[test]
fn test_client_builder_defaults() {
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
assert!(!client.thinking_enabled);
assert_eq!(client.max_tokens, 500);
assert_eq!(client.temperature, 1.0);
assert!(client.thinking_state.is_none());
}
#[test]
fn test_client_builder_with_thinking() {
let client = KimiClient::new("test-key", "kimi-k2.6")
.unwrap()
.with_thinking(true)
.with_max_tokens(1000)
.with_temperature(0.5);
assert!(client.thinking_enabled);
assert_eq!(client.max_tokens, 1000);
assert_eq!(client.temperature, 0.5);
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"enabled"}"#);
}
#[test]
fn test_client_new_defaults() {
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
assert_eq!(client.name(), "kimi");
assert!(!client.thinking_enabled);
}
#[test]
fn test_message_serialization() {
let msg = Message {
role: "user".to_string(),
content: "Hello".to_string(),
reasoning_content: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(!json.contains("reasoning_content"));
}
}

View File

@@ -1,20 +1,21 @@
use anyhow::{bail, Context, Result};
use crate::config::Language;
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::time::Duration;
use crate::config::Language;
pub mod anthropic;
pub mod deepseek;
pub mod kimi;
pub mod ollama;
pub mod openai;
pub mod anthropic;
pub mod kimi;
pub mod deepseek;
pub mod openrouter;
pub mod thinking;
pub use anthropic::AnthropicClient;
pub use deepseek::DeepSeekClient;
pub use kimi::KimiClient;
pub use ollama::OllamaClient;
pub use openai::OpenAiClient;
pub use anthropic::AnthropicClient;
pub use kimi::KimiClient;
pub use deepseek::DeepSeekClient;
pub use openrouter::OpenRouterClient;
/// LLM provider trait
@@ -22,13 +23,13 @@ pub use openrouter::OpenRouterClient;
pub trait LlmProvider: Send + Sync {
/// Generate text from prompt
async fn generate(&self, prompt: &str) -> Result<String>;
/// Generate with system prompt
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>;
/// Check if provider is available
async fn is_available(&self) -> bool;
/// Get provider name
fn name(&self) -> &str;
}
@@ -44,6 +45,7 @@ pub struct LlmClientConfig {
pub max_tokens: u32,
pub temperature: f32,
pub timeout: Duration,
pub thinking_enabled: bool,
}
impl Default for LlmClientConfig {
@@ -52,53 +54,131 @@ impl Default for LlmClientConfig {
max_tokens: 500,
temperature: 0.7,
timeout: Duration::from_secs(30),
thinking_enabled: false,
}
}
}
impl LlmClient {
/// Create LLM client from configuration
pub async fn from_config(config: &crate::config::LlmConfig) -> Result<Self> {
/// Create LLM client from configuration manager
pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result<Self> {
Self::from_config_with_think(manager, manager.config().llm.thinking_enabled).await
}
/// Create LLM client from configuration with explicit thinking override
pub async fn from_config_with_think(
manager: &crate::config::manager::ConfigManager,
thinking_enabled: bool,
) -> Result<Self> {
let config = manager.config();
let client_config = LlmClientConfig {
max_tokens: config.max_tokens,
temperature: config.temperature,
timeout: Duration::from_secs(config.timeout),
max_tokens: config.llm.max_tokens,
temperature: config.llm.temperature,
timeout: Duration::from_secs(config.llm.timeout),
thinking_enabled,
};
let provider: Box<dyn LlmProvider> = match config.provider.as_str() {
"ollama" => {
Box::new(OllamaClient::new(&config.ollama.url, &config.ollama.model))
}
let provider = config.llm.provider.as_str();
let model = config.llm.model.as_str();
let base_url = manager.llm_base_url();
let api_key = manager.get_api_key();
let provider: Box<dyn LlmProvider> = match provider {
"ollama" => Box::new(
OllamaClient::new(&base_url, model)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature),
),
"openai" => {
let api_key = config.openai.api_key.as_ref()
let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
Box::new(OpenAiClient::new(
&config.openai.base_url,
api_key,
&config.openai.model,
)?)
let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let mut client = OpenAiClient::new(&base_url, key, model)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
}
"anthropic" => {
let api_key = config.anthropic.api_key.as_ref()
let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
Box::new(AnthropicClient::new(api_key, &config.anthropic.model)?)
let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let budget = config.llm.thinking_budget_tokens.unwrap_or(1024);
let mut client = AnthropicClient::new(key, model)?
.with_thinking(thinking_enabled)
.with_thinking_budget_tokens(budget)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
}
"kimi" => {
let api_key = config.kimi.api_key.as_ref()
let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
Box::new(KimiClient::with_base_url(api_key, &config.kimi.model, &config.kimi.base_url)?)
let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let mut client = KimiClient::with_base_url(key, model, &base_url)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
}
"deepseek" => {
let api_key = config.deepseek.api_key.as_ref()
let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
Box::new(DeepSeekClient::with_base_url(api_key, &config.deepseek.model, &config.deepseek.base_url)?)
let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let mut client = DeepSeekClient::with_base_url(key, model, &base_url)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
}
"openrouter" => {
let api_key = config.openrouter.api_key.as_ref()
let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
Box::new(OpenRouterClient::with_base_url(api_key, &config.openrouter.model, &config.openrouter.base_url)?)
Box::new(
OpenRouterClient::with_base_url(key, model, &base_url)?
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?,
)
}
_ => bail!("Unknown LLM provider: {}", config.provider),
_ => bail!("Unknown LLM provider: {}", provider),
};
Ok(Self {
@@ -123,7 +203,7 @@ impl LlmClient {
language: Language,
) -> Result<GeneratedCommit> {
let system_prompt = get_commit_system_prompt(format, language);
// Add language instruction to the prompt
let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成提交消息。",
@@ -134,10 +214,13 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie die Commit-Nachricht auf Deutsch.",
Language::English => "",
};
let prompt = format!("{}{}", diff, language_instruction);
let response = self.provider.generate_with_system(system_prompt, &prompt).await?;
let response = self
.provider
.generate_with_system(system_prompt, &prompt)
.await?;
self.parse_commit_response(&response, format)
}
@@ -150,7 +233,7 @@ impl LlmClient {
) -> Result<String> {
let system_prompt = get_tag_system_prompt(language);
let commits_text = commits.join("\n");
// Add language instruction to the prompt
let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成标签消息。",
@@ -161,10 +244,15 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie die Tag-Nachricht auf Deutsch.",
Language::English => "",
};
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction);
self.provider.generate_with_system(system_prompt, &prompt).await
let prompt = format!(
"Version: {}\n\nCommits:\n{}{}",
version, commits_text, language_instruction
);
self.provider
.generate_with_system(system_prompt, &prompt)
.await
}
/// Generate changelog entry
@@ -175,13 +263,13 @@ impl LlmClient {
language: Language,
) -> Result<String> {
let system_prompt = get_changelog_system_prompt(language);
let commits_text = commits
.iter()
.map(|(t, m)| format!("- [{}] {}", t, m))
.collect::<Vec<_>>()
.join("\n");
// Add language instruction to the prompt
let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成变更日志。",
@@ -192,10 +280,15 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie das Changelog auf Deutsch.",
Language::English => "",
};
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction);
self.provider.generate_with_system(system_prompt, &prompt).await
let prompt = format!(
"Version: {}\n\nCommits:\n{}{}",
version, commits_text, language_instruction
);
self.provider
.generate_with_system(system_prompt, &prompt)
.await
}
/// Check if provider is available
@@ -204,35 +297,115 @@ impl LlmClient {
}
/// Parse commit response from LLM
fn parse_commit_response(&self, response: &str, format: crate::config::CommitFormat) -> Result<GeneratedCommit> {
let lines: Vec<&str> = response.lines().collect();
fn parse_commit_response(
&self,
response: &str,
format: crate::config::CommitFormat,
) -> Result<GeneratedCommit> {
// Clean markdown code fences from the response
let cleaned = Self::strip_code_fences(response);
let lines: Vec<&str> = cleaned
.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.collect();
if lines.is_empty() {
bail!("Empty response from LLM");
let preview: String = response.chars().take(200).collect();
bail!(
"LLM returned empty or whitespace-only response. \
Raw response preview: '{}'. \
Hint: If using DeepSeek/Kimi with thinking enabled, \
the model may have returned reasoning_content only. \
Try disabling thinking mode or switching models.",
preview
);
}
let first_line = lines[0];
// Find the line most likely to be the commit subject
let first_line = Self::find_commit_subject_line(&lines, format);
// Parse based on format
match format {
crate::config::CommitFormat::Conventional => {
self.parse_conventional_commit(first_line, lines)
self.parse_conventional_commit(first_line, &lines, response)
}
crate::config::CommitFormat::Commitlint => {
self.parse_commitlint_commit(first_line, lines)
self.parse_commitlint_commit(first_line, &lines, response)
}
}
}
/// Remove surrounding markdown code fences (```) from LLM output
fn strip_code_fences(response: &str) -> String {
let mut lines: Vec<&str> = response.lines().collect();
// Strip leading fence lines (``` or ```lang)
while lines.first().map_or(false, |l| l.trim().starts_with("```")) {
lines.remove(0);
}
// Strip trailing fence lines
while lines.last().map_or(false, |l| l.trim() == "```") {
lines.pop();
}
lines.join("\n")
}
/// Find the line that is most likely the commit subject among extracted lines
fn find_commit_subject_line<'a>(
lines: &[&'a str],
format: crate::config::CommitFormat,
) -> &'a str {
let valid_types = crate::utils::validators::get_commit_types(matches!(
format,
crate::config::CommitFormat::Commitlint
));
// First pass: line starting with a known type that also has proper syntax
// (e.g. "type:", "type(scope):", "type!:")
for &line in lines {
let trimmed = line.trim();
for &t in valid_types {
if let Some(rest) = trimmed.strip_prefix(t) {
if rest.starts_with(':') || rest.starts_with('(') || rest.starts_with("!:") {
return trimmed;
}
}
}
}
// Second pass: any line containing a colon (generic "prefix: description")
for &line in lines {
if line.contains(':') {
return line.trim();
}
}
// Fallback: return the first line as-is
lines[0].trim()
}
fn parse_conventional_commit(
&self,
first_line: &str,
lines: Vec<&str>,
lines: &[&str],
raw_response: &str,
) -> Result<GeneratedCommit> {
// Parse: type(scope)!: description
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 {
bail!("Invalid conventional commit format: missing colon");
let preview: String = raw_response.chars().take(300).collect();
bail!(
"Invalid conventional commit format: missing colon.\n\
Parsed subject line: '{}'\n\
Raw response preview: '{}'\n\
Expected: <type>[optional scope]: <description>",
first_line,
preview
);
}
let type_part = parts[0];
@@ -255,7 +428,7 @@ impl LlmClient {
};
// Extract body and footer
let (body, footer) = self.extract_body_footer(&lines);
let (body, footer) = self.extract_body_footer(lines);
Ok(GeneratedCommit {
commit_type,
@@ -270,12 +443,21 @@ impl LlmClient {
fn parse_commitlint_commit(
&self,
first_line: &str,
lines: Vec<&str>,
lines: &[&str],
raw_response: &str,
) -> Result<GeneratedCommit> {
// Similar parsing but with commitlint rules
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 {
bail!("Invalid commit format: missing colon");
let preview: String = raw_response.chars().take(300).collect();
bail!(
"Invalid commit format: missing colon.\n\
Parsed subject line: '{}'\n\
Raw response preview: '{}'\n\
Expected: <type>[optional scope]: <subject>",
first_line,
preview
);
}
let type_part = parts[0];
@@ -321,8 +503,14 @@ impl LlmClient {
}
// Look for footer markers
let footer_markers = ["BREAKING CHANGE:", "Closes", "Fixes", "Refs", "Co-authored-by:"];
let footer_markers = [
"BREAKING CHANGE:",
"Closes",
"Fixes",
"Refs",
"Co-authored-by:",
];
let mut body_lines = vec![];
let mut footer_lines = vec![];
let mut in_footer = false;
@@ -331,7 +519,7 @@ impl LlmClient {
if footer_markers.iter().any(|m| line.starts_with(m)) {
in_footer = true;
}
if in_footer {
footer_lines.push(*line);
} else {
@@ -401,17 +589,34 @@ pub(crate) fn create_http_client(timeout: Duration) -> Result<reqwest::Client> {
}
/// Get commit system prompt based on format and language
fn get_commit_system_prompt(format: crate::config::CommitFormat, language: Language) -> &'static str {
fn get_commit_system_prompt(
format: crate::config::CommitFormat,
language: Language,
) -> &'static str {
match (format, language) {
(crate::config::CommitFormat::Conventional, Language::Chinese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH,
(crate::config::CommitFormat::Conventional, Language::Japanese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA,
(crate::config::CommitFormat::Conventional, Language::Korean) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO,
(crate::config::CommitFormat::Conventional, Language::Spanish) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES,
(crate::config::CommitFormat::Conventional, Language::French) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR,
(crate::config::CommitFormat::Conventional, Language::German) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE,
(crate::config::CommitFormat::Conventional, Language::Chinese) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH
}
(crate::config::CommitFormat::Conventional, Language::Japanese) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA
}
(crate::config::CommitFormat::Conventional, Language::Korean) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO
}
(crate::config::CommitFormat::Conventional, Language::Spanish) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES
}
(crate::config::CommitFormat::Conventional, Language::French) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR
}
(crate::config::CommitFormat::Conventional, Language::German) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE
}
(crate::config::CommitFormat::Conventional, _) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT,
(crate::config::CommitFormat::Commitlint, Language::Chinese) => COMMITLINT_SYSTEM_PROMPT_ZH,
(crate::config::CommitFormat::Commitlint, Language::Japanese) => COMMITLINT_SYSTEM_PROMPT_JA,
(crate::config::CommitFormat::Commitlint, Language::Japanese) => {
COMMITLINT_SYSTEM_PROMPT_JA
}
(crate::config::CommitFormat::Commitlint, Language::Korean) => COMMITLINT_SYSTEM_PROMPT_KO,
(crate::config::CommitFormat::Commitlint, Language::Spanish) => COMMITLINT_SYSTEM_PROMPT_ES,
(crate::config::CommitFormat::Commitlint, Language::French) => COMMITLINT_SYSTEM_PROMPT_FR,
@@ -502,8 +707,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 C
4. 不要大写首字母
5. 结尾不要句号
6. 如果更改特定于模块/组件,请包含作用域
仅输出提交消息,不要输出其他内容。
7. 仅输出提交消息,不要输出其他内容。
"#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA: &str = r#"あなたはConventional Commits仕様に従ったコミットメッセージを生成するアシスタントです。
@@ -532,8 +736,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA: &str = r#"あなたはConventional C
4. 先頭を大文字にしない
5. 最後にピリオドを付けない
6. 変更がモジュール/コンポーネントに固有の場合はスコープを含める
コミットメッセージのみを出力し、それ以外は出力しないでください。
7. コミットメッセージのみを出力し、それ以外は出力しないでください。
"#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO: &str = r#"당신은 Conventional Commits 사양에 따른 커밋 메시지를 생성하는 도우미입니다.
@@ -562,8 +765,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO: &str = r#"당신은 Conventional Com
4. 첫 글자 대문자화하지 않음
5. 끝에 마침표 사용하지 않음
6. 변경 사항이 모듈/구성 요소에 특정한 경우 범위 포함
커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
7. 커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
"#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo la especificación Conventional Commits.
@@ -592,8 +794,7 @@ Reglas:
4. No capitalices la primera letra
5. Sin punto al final
6. Incluye alcance si el cambio es específico de un módulo/componente
Genera SOLO el mensaje de commit, nada más.
7. Genera SOLO el mensaje de commit, nada más.
"#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant la spécification Conventional Commits.
@@ -622,8 +823,7 @@ Règles:
4. Ne capitalisez pas la première lettre
5. Pas de point à la fin
6. Incluez la portée si le changement est spécifique à un module/composant
Générez SEULEMENT le message de commit, rien d'autre.
7. Générez SEULEMENT le message de commit, rien d'autre.
"#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß der Conventional Commits-Spezifikation generiert.
@@ -652,8 +852,7 @@ Regeln:
4. Großschreiben Sie den ersten Buchstaben nicht
5. Kein Punkt am Ende
6. Fügen Sie einen Bereich ein, wenn die Änderung spezifisch für ein Modul/Komponente ist
Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
7. Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
"#;
const COMMITLINT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates commit messages following @commitlint/config-conventional.
@@ -670,8 +869,7 @@ Rules:
3. Subject should be 4-100 characters
4. Use imperative mood
5. Be concise but descriptive
Output ONLY the commit message, nothing else.
6. Output ONLY the commit message, nothing else.
"#;
const COMMITLINT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 @commitlint/config-conventional 规范的提交消息的助手。
@@ -688,8 +886,7 @@ const COMMITLINT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 @commitlin
3. 主题应为 4-100 个字符
4. 使用祈使语气
5. 简洁但描述性强
仅输出提交消息,不要输出其他内容。
6. 仅输出提交消息,不要输出其他额外内容。
"#;
const COMMITLINT_SYSTEM_PROMPT_JA: &str = r#"あなたは@commitlint/config-conventionalに従ったコミットメッセージを生成するアシスタントです。
@@ -706,8 +903,7 @@ git diffを分析し、コミットメッセージを生成してください。
3. 件名は4-100文字である必要があります
4. 命令形を使用してください
5. 簡潔ですが説明的であること
コミットメッセージのみを出力し、それ以外は出力しないでください。
6. コミットメッセージのみを出力し、それ以外は出力しないでください。
"#;
const COMMITLINT_SYSTEM_PROMPT_KO: &str = r#"당신은 @commitlint/config-conventional에 따른 커밋 메시지를 생성하는 도우미입니다.
@@ -724,8 +920,7 @@ git diff를 분석하고 커밋 메시지를 생성하세요.
3. 제목은 4-100자여야 합니다
4. 명령형을 사용하세요
5. 간결하지만 설명적이어야 합니다
커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
6. 커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
"#;
const COMMITLINT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo @commitlint/config-conventional.
@@ -742,8 +937,7 @@ Reglas:
3. El asunto debe tener 4-100 caracteres
4. Usa modo imperativo
5. Sé conciso pero descriptivo
Genera SOLO el mensaje de commit, nada más.
6. Genera SOLO el mensaje de commit, nada más.
"#;
const COMMITLINT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant @commitlint/config-conventional.
@@ -760,8 +954,7 @@ Règles:
3. Le sujet doit avoir 4-100 caractères
4. Utilisez le mode impératif
5. Soyez concis mais descriptif
Générez SEULEMENT le message de commit, rien d'autre.
6. Générez SEULEMENT le message de commit, rien d'autre.
"#;
const COMMITLINT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß @commitlint/config-conventional generiert.
@@ -778,8 +971,7 @@ Regeln:
3. Der Betreff sollte 4-100 Zeichen haben
4. Verwenden Sie den Imperativ
5. Seien Sie prägnant aber beschreibend
Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
6. Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
"#;
const TAG_MESSAGE_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates git tag annotation messages.
@@ -1012,3 +1204,10 @@ Gruppieren Sie Commits nach:
Formatieren Sie in Markdown mit geeigneten Überschriften und Aufzählungspunkten.
"#;
/// Test LLM connection
pub async fn test_connection(manager: &crate::config::manager::ConfigManager) -> Result<String> {
let client = LlmClient::from_config(manager).await?;
let response = client.provider.generate("Say 'Hello, World!'").await?;
Ok(response)
}

View File

@@ -1,4 +1,4 @@
use super::{create_http_client, LlmProvider};
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
@@ -9,6 +9,9 @@ pub struct OllamaClient {
base_url: String,
model: String,
client: reqwest::Client,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
}
#[derive(Debug, Serialize)]
@@ -47,69 +50,88 @@ struct ModelInfo {
impl OllamaClient {
/// Create new Ollama client
pub fn new(base_url: &str, model: &str) -> Self {
let client = create_http_client(Duration::from_secs(120))
.expect("Failed to create HTTP client");
let client =
create_http_client(Duration::from_secs(120)).expect("Failed to create HTTP client");
Self {
base_url: base_url.trim_end_matches('/').to_string(),
model: model.to_string(),
client,
max_tokens: 500,
temperature: 0.7,
top_p: None,
}
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.client = create_http_client(timeout)
.expect("Failed to create HTTP client");
self.client = create_http_client(timeout).expect("Failed to create HTTP client");
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/api/tags", self.base_url);
let response = self.client
let response = self
.client
.get(&url)
.send()
.await
.context("Failed to list Ollama models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text);
}
let result: ListModelsResponse = response
.json()
.await
.context("Failed to parse Ollama response")?;
Ok(result.models.into_iter().map(|m| m.name).collect())
}
/// Pull a model
pub async fn pull_model(&self, model: &str) -> Result<()> {
let url = format!("{}/api/pull", self.base_url);
let request = serde_json::json!({
"name": model,
"stream": false,
});
let response = self.client
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.context("Failed to pull Ollama model")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama pull error: {} - {}", status, text);
}
Ok(())
}
@@ -130,48 +152,49 @@ impl LlmProvider for OllamaClient {
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let url = format!("{}/api/generate", self.base_url);
let system = if system.is_empty() {
None
} else {
Some(system.to_string())
};
let request = GenerateRequest {
model: self.model.clone(),
prompt: user.to_string(),
system,
stream: false,
options: GenerationOptions {
temperature: Some(0.7),
num_predict: Some(500),
temperature: Some(self.temperature),
num_predict: Some(self.max_tokens),
},
};
let response = self.client
let response = self
.client
.post(&url)
.json(&request)
.send()
.await
.context("Failed to send request to Ollama")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text);
}
let result: GenerateResponse = response
.json()
.await
.context("Failed to parse Ollama response")?;
Ok(result.response.trim().to_string())
}
async fn is_available(&self) -> bool {
let url = format!("{}/api/tags", self.base_url);
match self.client.get(&url).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,

View File

@@ -1,15 +1,23 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// OpenAI API client
/// OpenAI API client with o-series reasoning support
pub struct OpenAiClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
}
#[derive(Debug, Serialize)]
@@ -20,10 +28,14 @@ struct ChatCompletionRequest {
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Message {
role: String,
content: String,
@@ -39,6 +51,28 @@ struct Choice {
message: Message,
}
// --- Streaming response structures ---
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
@@ -55,57 +89,91 @@ impl OpenAiClient {
/// Create new OpenAI client
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
self.reasoning_effort = effort;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self.client
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.context("Failed to list OpenAI models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("OpenAI API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
}
#[derive(Deserialize)]
struct Model {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse OpenAI response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
@@ -124,32 +192,30 @@ impl OpenAiClient {
#[async_trait]
impl LlmProvider for OpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}];
self.chat_completion_with_retry(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
self.chat_completion_with_retry(messages).await
}
async fn is_available(&self) -> bool {
@@ -162,18 +228,63 @@ impl LlmProvider for OpenAiClient {
}
impl OpenAiClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
if self.thinking_enabled {
self.streaming_chat_completion(messages).await
} else {
self.non_streaming_chat_completion(messages).await
}
}
async fn non_streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
top_p: self.top_p,
reasoning_effort: if is_reasoning_model(&self.model) {
Some("none".to_string())
} else {
None
},
stream: false,
};
let response = self.client
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
@@ -181,31 +292,166 @@ impl OpenAiClient {
.send()
.await
.context("Failed to send request to OpenAI")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type);
bail!(
"OpenAI API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("OpenAI API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse OpenAI response")?;
result.choices
result
.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
}
/// Streaming request for reasoning mode, filters reasoning_content from output
async fn streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
// For reasoning/thinking mode, omit temperature and top_p
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(self.max_tokens),
temperature: None,
top_p: None,
reasoning_effort: self.reasoning_effort.clone(),
stream: true,
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(&request)
.send()
.await
.context("Failed to send streaming request to OpenAI")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"OpenAI API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("OpenAI API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
if line == "data: [DONE]" {
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(json_str) {
for choice in &chunk.choices {
// Handle reasoning_content (o-series)
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
continue;
}
// Handle content
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
}
}
}
}
}
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"OpenAI returned reasoning content but no final answer. \
The model may have entered an incomplete reasoning state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from OpenAI. \
If thinking mode is enabled, try disabling it or ensure the model supports reasoning."
);
}
Ok(result)
}
}
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
@@ -215,24 +461,30 @@ pub struct AzureOpenAiClient {
deployment: String,
api_version: String,
client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
}
impl AzureOpenAiClient {
/// Create new Azure OpenAI client
pub fn new(
endpoint: &str,
api_key: &str,
deployment: &str,
api_version: &str,
) -> Result<Self> {
pub fn new(endpoint: &str, api_key: &str, deployment: &str, api_version: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
endpoint: endpoint.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
deployment: deployment.to_string(),
api_version: api_version.to_string(),
client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
})
}
@@ -241,16 +493,19 @@ impl AzureOpenAiClient {
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version
);
let request = ChatCompletionRequest {
model: self.deployment.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
top_p: self.top_p,
reasoning_effort: self.reasoning_effort.clone(),
stream: false,
};
let response = self.client
let response = self
.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
@@ -258,22 +513,24 @@ impl AzureOpenAiClient {
.send()
.await
.context("Failed to send request to Azure OpenAI")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("Azure OpenAI API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Azure OpenAI response")?;
result.choices
result
.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
}
}
@@ -281,41 +538,38 @@ impl AzureOpenAiClient {
#[async_trait]
impl LlmProvider for AzureOpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
// Simple check - try to make a minimal request
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version
);
let request = ChatCompletionRequest {
model: self.deployment.clone(),
messages: vec![Message {
@@ -324,10 +578,13 @@ impl LlmProvider for AzureOpenAiClient {
}],
max_tokens: Some(5),
temperature: Some(0.0),
top_p: None,
reasoning_effort: None,
stream: false,
};
match self.client
match self
.client
.post(&url)
.header("api-key", &self.api_key)
.json(&request)
@@ -343,3 +600,60 @@ impl LlmProvider for AzureOpenAiClient {
"azure-openai"
}
}
/// Available OpenAI models (including o-series with reasoning)
pub const OPENAI_MODELS: &[&str] = &[
"o4-mini",
"o3",
"o3-mini",
"o1",
"o1-mini",
"o1-pro",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-3.5-turbo",
];
pub fn is_valid_model(model: &str) -> bool {
OPENAI_MODELS.contains(&model)
}
fn is_reasoning_model(model: &str) -> bool {
model.starts_with("o")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation_o_series() {
assert!(is_valid_model("o4-mini"));
assert!(is_valid_model("o3"));
assert!(is_valid_model("o1"));
assert!(is_valid_model("gpt-4o"));
assert!(is_valid_model("gpt-3.5-turbo"));
assert!(!is_valid_model("invalid-model"));
}
#[test]
fn test_stream_delta_reasoning_parsing() {
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert!(delta.content.is_none());
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
}
#[test]
fn test_stream_delta_content_parsing() {
let json = r#"{"content":"Hello","reasoning_content":null}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert_eq!(delta.content, Some("Hello".to_string()));
assert!(delta.reasoning_content.is_none());
}
}

View File

@@ -1,257 +1,286 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// OpenRouter API client
pub struct OpenRouterClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
}
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
#[serde(rename = "type")]
error_type: String,
}
impl OpenRouterClient {
/// Create new OpenRouter client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: "https://openrouter.ai/api/v1".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Create with custom base URL
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("HTTP-Referer", "https://quicommit.dev")
.header("X-Title", "QuiCommit")
.send()
.await
.context("Failed to list OpenRouter models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("OpenRouter API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
}
#[derive(Deserialize)]
struct Model {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse OpenRouter response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for OpenRouterClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"openrouter"
}
}
impl OpenRouterClient {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
stream: false,
};
let response = self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("HTTP-Referer", "https://quicommit.dev")
.header("X-Title", "QuiCommit")
.json(&request)
.send()
.await
.context("Failed to send request to OpenRouter")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type);
}
bail!("OpenRouter API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse OpenRouter response")?;
result.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
}
/// Popular OpenRouter models
pub const OPENROUTER_MODELS: &[&str] = &[
"openai/gpt-3.5-turbo",
"openai/gpt-4",
"openai/gpt-4-turbo",
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
"google/gemini-pro",
"meta-llama/llama-2-70b-chat",
"mistralai/mixtral-8x7b-instruct",
"01-ai/yi-34b-chat",
];
/// Check if a model name is valid
pub fn is_valid_model(_model: &str) -> bool {
// Since OpenRouter supports many models, we'll allow any model name
// but provide some popular ones as suggestions
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation() {
assert!(is_valid_model("openai/gpt-4"));
assert!(is_valid_model("custom/model"));
}
}
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// OpenRouter API client
pub struct OpenRouterClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
}
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
#[serde(rename = "type")]
error_type: String,
}
impl OpenRouterClient {
/// Create new OpenRouter client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: "https://openrouter.ai/api/v1".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
max_tokens: 500,
temperature: 0.7,
top_p: None,
})
}
/// Create with custom base URL
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
max_tokens: 500,
temperature: 0.7,
top_p: None,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("HTTP-Referer", "https://quicommit.dev")
.header("X-Title", "QuiCommit")
.send()
.await
.context("Failed to list OpenRouter models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("OpenRouter API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
}
#[derive(Deserialize)]
struct Model {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse OpenRouter response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for OpenRouterClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
}];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"openrouter"
}
}
impl OpenRouterClient {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(self.max_tokens),
temperature: Some(self.temperature),
stream: false,
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("HTTP-Referer", "https://quicommit.dev")
.header("X-Title", "QuiCommit")
.json(&request)
.send()
.await
.context("Failed to send request to OpenRouter")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"OpenRouter API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("OpenRouter API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse OpenRouter response")?;
result
.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
}
/// Popular OpenRouter models
pub const OPENROUTER_MODELS: &[&str] = &[
"openai/gpt-3.5-turbo",
"openai/gpt-4",
"openai/gpt-4-turbo",
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
"google/gemini-pro",
"meta-llama/llama-2-70b-chat",
"mistralai/mixtral-8x7b-instruct",
"01-ai/yi-34b-chat",
];
/// Check if a model name is valid
pub fn is_valid_model(_model: &str) -> bool {
// Since OpenRouter supports many models, we'll allow any model name
// but provide some popular ones as suggestions
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation() {
assert!(is_valid_model("openai/gpt-4"));
assert!(is_valid_model("custom/model"));
}
}

151
src/llm/thinking.rs Normal file
View File

@@ -0,0 +1,151 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏
pub struct ThinkingStateManager {
is_thinking: AtomicBool,
on_start: Option<Box<dyn Fn() + Send + Sync>>,
on_end: Option<Box<dyn Fn() + Send + Sync>>,
}
impl ThinkingStateManager {
pub fn new() -> Self {
Self {
is_thinking: AtomicBool::new(false),
on_start: None,
on_end: None,
}
}
/// 设置思考开始回调
pub fn on_thinking_start<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
self.on_start = Some(Box::new(callback));
self
}
/// 设置思考结束回调
pub fn on_thinking_end<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
self.on_end = Some(Box::new(callback));
self
}
/// 开始思考状态
pub fn start_thinking(&self) {
if !self.is_thinking.load(Ordering::SeqCst) {
self.is_thinking.store(true, Ordering::SeqCst);
if let Some(ref cb) = self.on_start {
cb();
}
}
}
/// 结束思考状态
pub fn end_thinking(&self) {
if self.is_thinking.load(Ordering::SeqCst) {
self.is_thinking.store(false, Ordering::SeqCst);
if let Some(ref cb) = self.on_end {
cb();
}
}
}
/// 当前是否处于思考状态
pub fn is_thinking(&self) -> bool {
self.is_thinking.load(Ordering::SeqCst)
}
}
impl Default for ThinkingStateManager {
fn default() -> Self {
Self::new()
}
}
/// 线程安全的思考状态管理器引用
pub type SharedThinkingState = Arc<ThinkingStateManager>;
/// 创建带有默认控制台输出的思考状态管理器
/// 在思考开始时打印 "thinking...",在思考结束时清除该标识
pub fn create_console_thinking_state() -> SharedThinkingState {
Arc::new(
ThinkingStateManager::new()
.on_thinking_start(|| {
eprint!("\rthinking...");
})
.on_thinking_end(|| {
eprint!("\r \r");
}),
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
#[test]
fn test_thinking_state_transitions() {
let manager = ThinkingStateManager::new();
assert!(!manager.is_thinking());
manager.start_thinking();
assert!(manager.is_thinking());
manager.end_thinking();
assert!(!manager.is_thinking());
}
#[test]
fn test_thinking_idempotent_start() {
let manager = ThinkingStateManager::new();
manager.start_thinking();
manager.start_thinking(); // 重复调用不应触发回调两次
assert!(manager.is_thinking());
}
#[test]
fn test_thinking_idempotent_end() {
let manager = ThinkingStateManager::new();
manager.end_thinking(); // 未开始时结束不应触发问题
assert!(!manager.is_thinking());
}
#[test]
fn test_thinking_callbacks() {
let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
let manager = ThinkingStateManager::new().on_thinking_start(move || {
events_clone.lock().unwrap().push("start".to_string());
});
let events_clone2 = events.clone();
let manager = manager.on_thinking_end(move || {
events_clone2.lock().unwrap().push("end".to_string());
});
manager.start_thinking();
manager.end_thinking();
let recorded = events.lock().unwrap();
assert_eq!(recorded.len(), 2);
assert_eq!(recorded[0], "start");
assert_eq!(recorded[1], "end");
}
#[test]
fn test_create_console_thinking_state() {
let state = create_console_thinking_state();
assert!(!state.is_thinking());
state.start_thinking();
assert!(state.is_thinking());
state.end_thinking();
assert!(!state.is_thinking());
}
#[test]
fn test_default() {
let manager = ThinkingStateManager::default();
assert!(!manager.is_thinking());
}
}

View File

@@ -1,3 +1,5 @@
#![allow(dead_code)]
use anyhow::Result;
use clap::{Parser, Subcommand};
use std::path::PathBuf;
@@ -12,12 +14,12 @@ mod llm;
mod utils;
use commands::{
changelog::ChangelogCommand, commit::CommitCommand, config::ConfigCommand,
init::InitCommand, profile::ProfileCommand, tag::TagCommand,
changelog::ChangelogCommand, commit::CommitCommand, config::ConfigCommand, init::InitCommand,
profile::ProfileCommand, tag::TagCommand,
};
/// QuiCommit - AI-powered Git assistant
///
///
/// A powerful tool that helps you generate conventional commits, tags, and changelogs
/// using AI (LLM APIs or local Ollama models). Manage multiple Git profiles for different
/// work contexts seamlessly.
@@ -81,7 +83,7 @@ async fn main() -> Result<()> {
2 => "debug",
_ => "trace",
};
tracing_subscriber::fmt()
.with_env_filter(log_level)
.with_target(false)

View File

@@ -1,9 +1,9 @@
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
aead::{Aead, KeyInit},
};
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use rand::Rng;
use std::fs;
use std::path::Path;
@@ -18,63 +18,62 @@ pub fn encrypt(data: &[u8], password: &str) -> Result<String> {
rand::thread_rng().fill(&mut salt);
let mut nonce_bytes = [0u8; NONCE_LEN];
rand::thread_rng().fill(&mut nonce_bytes);
let key = derive_key(password, &salt)?;
let cipher = Aes256Gcm::new_from_slice(&key)
.context("Failed to create cipher")?;
let cipher = Aes256Gcm::new_from_slice(&key).context("Failed to create cipher")?;
let nonce = Nonce::from_slice(&nonce_bytes);
let encrypted = cipher
.encrypt(nonce, data)
.map_err(|e| anyhow::anyhow!("Encryption failed: {:?}", e))?;
// Combine salt + nonce + encrypted data
let mut result = Vec::with_capacity(SALT_LEN + NONCE_LEN + encrypted.len());
result.extend_from_slice(&salt);
result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&encrypted);
Ok(BASE64.encode(&result))
}
/// Decrypt data with password
pub fn decrypt(encrypted_data: &str, password: &str) -> Result<Vec<u8>> {
let data = BASE64.decode(encrypted_data)
let data = BASE64
.decode(encrypted_data)
.context("Invalid base64 encoding")?;
if data.len() < SALT_LEN + NONCE_LEN {
anyhow::bail!("Invalid encrypted data format");
}
let salt = &data[..SALT_LEN];
let nonce_bytes = &data[SALT_LEN..SALT_LEN + NONCE_LEN];
let encrypted = &data[SALT_LEN + NONCE_LEN..];
let key = derive_key(password, salt)?;
let cipher = Aes256Gcm::new_from_slice(&key)
.context("Failed to create cipher")?;
let cipher = Aes256Gcm::new_from_slice(&key).context("Failed to create cipher")?;
let nonce = Nonce::from_slice(nonce_bytes);
let decrypted = cipher
.decrypt(nonce, encrypted)
.map_err(|e| anyhow::anyhow!("Decryption failed: {:?}", e))?;
Ok(decrypted)
}
/// Derive key from password using simple method
fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; KEY_LEN]> {
use sha2::{Sha256, Digest};
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(salt);
hasher.update(password.as_bytes());
hasher.update(b"quicommit_key_derivation_v1");
let hash = hasher.finalize();
let mut key = [0u8; KEY_LEN];
key.copy_from_slice(&hash[..KEY_LEN]);
Ok(key)
}
@@ -97,7 +96,7 @@ pub fn decrypt_from_file(path: &Path, password: &str) -> Result<Vec<u8>> {
pub fn generate_token(length: usize) -> String {
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
let mut rng = rand::thread_rng();
(0..length)
.map(|_| {
let idx = rng.gen_range(0..CHARSET.len());
@@ -122,10 +121,10 @@ mod tests {
fn test_encrypt_decrypt() {
let data = b"Hello, World!";
let password = "my_secret_password";
let encrypted = encrypt(data, password).unwrap();
let decrypted = decrypt(&encrypted, password).unwrap();
assert_eq!(data.to_vec(), decrypted);
}
@@ -133,7 +132,7 @@ mod tests {
fn test_wrong_password() {
let data = b"Hello, World!";
let encrypted = encrypt(data, "correct_password").unwrap();
assert!(decrypt(&encrypted, "wrong_password").is_err());
}
}

View File

@@ -9,15 +9,12 @@ pub fn edit_content(initial_content: &str) -> Result<String> {
/// Edit file in user's default editor
pub fn edit_file(path: &Path) -> Result<String> {
let content = fs::read_to_string(path)
.unwrap_or_default();
let edited = edit::edit(&content)
.context("Failed to open editor")?;
fs::write(path, &edited)
.with_context(|| format!("Failed to write file: {:?}", path))?;
let content = fs::read_to_string(path).unwrap_or_default();
let edited = edit::edit(&content).context("Failed to open editor")?;
fs::write(path, &edited).with_context(|| format!("Failed to write file: {:?}", path))?;
Ok(edited)
}
@@ -27,11 +24,10 @@ pub fn edit_temp(initial_content: &str, extension: &str) -> Result<String> {
.suffix(&format!(".{}", extension))
.tempfile()
.context("Failed to create temp file")?;
let path = temp_file.path();
fs::write(path, initial_content)
.context("Failed to write temp file")?;
fs::write(path, initial_content).context("Failed to write temp file")?;
edit_file(path)
}
@@ -41,10 +37,10 @@ pub fn get_editor() -> String {
.or_else(|_| std::env::var("VISUAL"))
.unwrap_or_else(|_| {
if cfg!(target_os = "windows") {
if let Ok(code) = which::which("code") {
if let Ok(_code) = which::which("code") {
return "code --wait".to_string();
}
if let Ok(notepad) = which::which("notepad") {
if let Ok(_notepad) = which::which("notepad") {
return "notepad".to_string();
}
"notepad".to_string()
@@ -65,7 +61,6 @@ pub fn get_editor() -> String {
/// Check if editor is available
pub fn check_editor() -> Result<()> {
let editor = get_editor();
which::which(&editor)
.with_context(|| format!("Editor '{}' not found in PATH", editor))?;
which::which(&editor).with_context(|| format!("Editor '{}' not found in PATH", editor))?;
Ok(())
}

View File

@@ -10,7 +10,7 @@ pub fn format_conventional_commit(
breaking: bool,
) -> String {
let mut message = String::new();
message.push_str(commit_type);
if let Some(s) = scope {
message.push_str(&format!("({})", s));
@@ -19,15 +19,15 @@ pub fn format_conventional_commit(
message.push('!');
}
message.push_str(&format!(": {}", description));
if let Some(b) = body {
message.push_str(&format!("\n\n{}", b));
}
if let Some(f) = footer {
message.push_str(&format!("\n\n{}", f));
}
message
}
@@ -41,27 +41,27 @@ pub fn format_commitlint_commit(
references: Option<&[&str]>,
) -> String {
let mut message = String::new();
message.push_str(commit_type);
if let Some(s) = scope {
message.push_str(&format!("({})", s));
}
message.push_str(&format!(": {}", subject));
if let Some(refs) = references {
for reference in refs {
message.push_str(&format!(" #{}", reference));
}
}
if let Some(b) = body {
message.push_str(&format!("\n\n{}", b));
}
if let Some(f) = footer {
message.push_str(&format!("\n\n{}", f));
}
message
}
@@ -73,7 +73,7 @@ pub fn wrap_text(text: &str, width: usize) -> String {
/// Clean commit message (remove comments, extra whitespace)
pub fn clean_message(message: &str) -> String {
let comment_regex = Regex::new(r"^#.*$").unwrap();
message
.lines()
.filter(|line| !comment_regex.is_match(line.trim()))
@@ -97,7 +97,7 @@ mod tests {
Some("Closes #123"),
false,
);
assert!(msg.contains("feat(auth): add login functionality"));
assert!(msg.contains("This adds OAuth2 login support."));
assert!(msg.contains("Closes #123"));
@@ -113,7 +113,7 @@ mod tests {
Some("BREAKING CHANGE: response format changed"),
true,
);
assert!(msg.starts_with("feat!: change API response format"));
}
}

350
src/utils/keyring.rs Normal file
View File

@@ -0,0 +1,350 @@
use anyhow::{Context, Result, bail};
use std::env;
const SERVICE_NAME: &str = "quicommit";
const ENV_API_KEY: &str = "QUICOMMIT_API_KEY";
const PAT_SERVICE_PREFIX: &str = "quicommit/pat";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyringStatus {
Available,
Unavailable,
}
pub struct KeyringManager {
status: KeyringStatus,
}
impl KeyringManager {
pub fn new() -> Self {
let status = Self::check_keyring_availability();
Self { status }
}
pub fn check_keyring_availability() -> KeyringStatus {
#[cfg(target_os = "windows")]
{
KeyringStatus::Available
}
#[cfg(target_os = "macos")]
{
KeyringStatus::Available
}
#[cfg(target_os = "linux")]
{
Self::check_linux_keyring()
}
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
{
KeyringStatus::Unavailable
}
}
#[cfg(target_os = "linux")]
fn check_linux_keyring() -> KeyringStatus {
use std::path::Path;
let has_dbus = Path::new("/usr/bin/dbus-daemon").exists()
|| Path::new("/bin/dbus-daemon").exists()
|| env::var("DBUS_SESSION_BUS_ADDRESS").is_ok();
let has_keyring = Path::new("/usr/bin/gnome-keyring-daemon").exists()
|| Path::new("/usr/bin/gnome-keyring").exists()
|| Path::new("/usr/bin/kwalletd5").exists()
|| Path::new("/usr/bin/kwalletd6").exists()
|| env::var("SECRET_SERVICE").is_ok();
if has_dbus && has_keyring {
KeyringStatus::Available
} else {
KeyringStatus::Unavailable
}
}
pub fn status(&self) -> KeyringStatus {
self.status
}
pub fn is_available(&self) -> bool {
self.status == KeyringStatus::Available
}
pub fn store_api_key(&self, provider: &str, api_key: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
entry
.set_password(api_key)
.context("Failed to store API key")?;
Ok(())
}
pub fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
if let Ok(key) = env::var(ENV_API_KEY)
&& !key.is_empty()
{
return Ok(Some(key));
}
if !self.is_available() {
return Ok(None);
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
match entry.get_password() {
Ok(key) => Ok(Some(key)),
Err(keyring::Error::NoEntry) => Ok(None),
Err(e) => Err(e.into()),
}
}
pub fn delete_api_key(&self, provider: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
entry
.delete_credential()
.context("Failed to delete API key")?;
Ok(())
}
pub fn has_api_key(&self, provider: &str) -> bool {
self.get_api_key(provider).unwrap_or(None).is_some()
}
fn make_pat_service_name(profile_name: &str) -> String {
format!("{}/{}", PAT_SERVICE_PREFIX, profile_name)
}
pub fn store_pat(
&self,
profile_name: &str,
user_email: &str,
service: &str,
token: &str,
) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
entry
.set_password(token)
.context("Failed to store PAT in keyring")?;
eprintln!(
"[DEBUG] PAT stored in keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(())
}
pub fn get_pat(
&self,
profile_name: &str,
user_email: &str,
service: &str,
) -> Result<Option<String>> {
if !self.is_available() {
return Ok(None);
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
match entry.get_password() {
Ok(token) => {
eprintln!(
"[DEBUG] PAT retrieved from keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(Some(token))
}
Err(keyring::Error::NoEntry) => {
eprintln!(
"[DEBUG] PAT not found in keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(None)
}
Err(e) => Err(e.into()),
}
}
pub fn delete_pat(&self, profile_name: &str, user_email: &str, service: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
entry
.delete_credential()
.context("Failed to delete PAT from keyring")?;
eprintln!(
"[DEBUG] PAT deleted from keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(())
}
pub fn has_pat(&self, profile_name: &str, user_email: &str, service: &str) -> bool {
self.get_pat(profile_name, user_email, service)
.unwrap_or(None)
.is_some()
}
pub fn delete_all_pats_for_profile(
&self,
profile_name: &str,
user_email: &str,
services: &[String],
) -> Result<()> {
for service in services {
if let Err(e) = self.delete_pat(profile_name, user_email, service) {
eprintln!(
"[DEBUG] Failed to delete PAT for service '{}': {}",
service, e
);
}
}
Ok(())
}
pub fn get_status_message(&self) -> String {
match self.status {
KeyringStatus::Available => {
#[cfg(target_os = "windows")]
{
"Windows Credential Manager is available".to_string()
}
#[cfg(target_os = "macos")]
{
"macOS Keychain is available".to_string()
}
#[cfg(target_os = "linux")]
{
"Linux secret service is available".to_string()
}
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
{
"Keyring is available".to_string()
}
}
KeyringStatus::Unavailable => {
"Keyring is not available. Set QUICOMMIT_API_KEY environment variable.".to_string()
}
}
}
}
impl Default for KeyringManager {
fn default() -> Self {
Self::new()
}
}
pub fn get_default_base_url(provider: &str) -> &'static str {
match provider {
"openai" => "https://api.openai.com/v1",
"anthropic" => "https://api.anthropic.com/v1",
"kimi" => "https://api.moonshot.cn/v1",
"deepseek" => "https://api.deepseek.com/v1",
"openrouter" => "https://openrouter.ai/api/v1",
"ollama" => "http://localhost:11434",
_ => "",
}
}
pub fn get_default_model(provider: &str) -> &'static str {
match provider {
"openai" => "gpt-4",
"anthropic" => "claude-3-sonnet-20240229",
"kimi" => "kimi-k2.6",
"deepseek" => "deepseek-v4-flash",
"openrouter" => "openai/gpt-3.5-turbo",
"ollama" => "llama2",
_ => "",
}
}
pub fn get_supported_providers() -> &'static [&'static str] {
&[
"ollama",
"openai",
"anthropic",
"kimi",
"deepseek",
"openrouter",
]
}
pub fn provider_needs_api_key(provider: &str) -> bool {
provider != "ollama"
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_default_base_url() {
assert_eq!(get_default_base_url("openai"), "https://api.openai.com/v1");
assert_eq!(
get_default_base_url("anthropic"),
"https://api.anthropic.com/v1"
);
assert_eq!(get_default_base_url("kimi"), "https://api.moonshot.cn/v1");
assert_eq!(
get_default_base_url("deepseek"),
"https://api.deepseek.com/v1"
);
assert_eq!(
get_default_base_url("openrouter"),
"https://openrouter.ai/api/v1"
);
assert_eq!(get_default_base_url("ollama"), "http://localhost:11434");
}
#[test]
fn test_get_default_model() {
assert_eq!(get_default_model("openai"), "gpt-4");
assert_eq!(get_default_model("anthropic"), "claude-3-sonnet-20240229");
assert_eq!(get_default_model("ollama"), "llama2");
}
#[test]
fn test_provider_needs_api_key() {
assert!(provider_needs_api_key("openai"));
assert!(provider_needs_api_key("anthropic"));
assert!(!provider_needs_api_key("ollama"));
}
}

View File

@@ -1,6 +1,7 @@
pub mod crypto;
pub mod editor;
pub mod formatter;
pub mod keyring;
pub mod validators;
use anyhow::{Context, Result};
@@ -31,10 +32,10 @@ pub fn print_info(msg: &str) {
pub fn confirm(prompt: &str) -> Result<bool> {
print!("{} [y/N] ", prompt);
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
Ok(input.trim().to_lowercase().starts_with('y'))
}
@@ -42,17 +43,17 @@ pub fn confirm(prompt: &str) -> Result<bool> {
pub fn input(prompt: &str) -> Result<String> {
print!("{}: ", prompt);
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
Ok(input.trim().to_string())
}
/// Get password input (hidden)
pub fn password_input(prompt: &str) -> Result<String> {
use dialoguer::Password;
Password::new()
.with_prompt(prompt)
.interact()

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result};
use anyhow::{Result, bail};
use lazy_static::lazy_static;
use regex::Regex;
@@ -67,7 +67,7 @@ lazy_static! {
/// Validate conventional commit message
pub fn validate_conventional_commit(message: &str) -> Result<()> {
let first_line = message.lines().next().unwrap_or("");
if !CONVENTIONAL_COMMIT_REGEX.is_match(first_line) {
bail!(
"Invalid conventional commit format. Expected: <type>[optional scope]: <description>\n\
@@ -75,32 +75,32 @@ pub fn validate_conventional_commit(message: &str) -> Result<()> {
CONVENTIONAL_TYPES.join(", ")
);
}
if first_line.len() > 100 {
bail!("Commit subject too long (max 100 characters)");
}
Ok(())
}
/// Validate @commitlint commit message
pub fn validate_commitlint_commit(message: &str) -> Result<()> {
let first_line = message.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 {
bail!("Invalid commit format. Expected: <type>[optional scope]: <subject>");
}
let type_part = parts[0];
let subject = parts[1].trim();
let commit_type = type_part
.split('(')
.next()
.unwrap_or("")
.trim_end_matches('!');
if !COMMITLINT_TYPES.contains(&commit_type) {
bail!(
"Invalid commit type: '{}'. Valid types: {}",
@@ -108,27 +108,32 @@ pub fn validate_commitlint_commit(message: &str) -> Result<()> {
COMMITLINT_TYPES.join(", ")
);
}
if subject.is_empty() {
bail!("Commit subject cannot be empty");
}
if subject.len() < 4 {
bail!("Commit subject too short (min 4 characters)");
}
if subject.len() > 100 {
bail!("Commit subject too long (max 100 characters)");
}
if subject.chars().next().map(|c| c.is_uppercase()).unwrap_or(false) {
if subject
.chars()
.next()
.map(|c| c.is_uppercase())
.unwrap_or(false)
{
bail!("Commit subject should not start with uppercase letter");
}
if subject.ends_with('.') {
bail!("Commit subject should not end with a period");
}
Ok(())
}
@@ -137,25 +142,25 @@ pub fn validate_scope(scope: &str) -> Result<()> {
if scope.is_empty() {
bail!("Scope cannot be empty");
}
if !SCOPE_REGEX.is_match(scope) {
bail!("Invalid scope format. Use lowercase letters, numbers, and hyphens only");
}
Ok(())
}
/// Validate semantic version tag
pub fn validate_semver(version: &str) -> Result<()> {
let version = version.trim_start_matches('v');
if !SEMVER_REGEX.is_match(version) {
bail!(
"Invalid semantic version format. Expected: MAJOR.MINOR.PATCH[-prerelease][+build]\n\
Examples: 1.0.0, 1.2.3-beta, v2.0.0+build123"
);
}
Ok(())
}
@@ -164,7 +169,7 @@ pub fn validate_email(email: &str) -> Result<()> {
if !EMAIL_REGEX.is_match(email) {
bail!("Invalid email address format");
}
Ok(())
}
@@ -173,7 +178,7 @@ pub fn validate_gpg_key_id(key_id: &str) -> Result<()> {
if !GPG_KEY_ID_REGEX.is_match(key_id) {
bail!("Invalid GPG key ID format. Expected 16-40 hexadecimal characters");
}
Ok(())
}
@@ -182,15 +187,18 @@ pub fn validate_profile_name(name: &str) -> Result<()> {
if name.is_empty() {
bail!("Profile name cannot be empty");
}
if name.len() > 50 {
bail!("Profile name too long (max 50 characters)");
}
if !name.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_') {
if !name
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
{
bail!("Profile name can only contain letters, numbers, hyphens, and underscores");
}
Ok(())
}
@@ -201,7 +209,7 @@ pub fn is_valid_commit_type(commit_type: &str, use_commitlint: bool) -> bool {
} else {
CONVENTIONAL_TYPES
};
types.contains(&commit_type)
}

View File

@@ -0,0 +1,434 @@
use assert_cmd::cargo::cargo_bin_cmd;
use predicates::prelude::*;
use std::fs;
use std::path::PathBuf;
use tempfile::TempDir;
fn init_quicommit(config_path: &PathBuf) {
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
}
mod config_export {
use super::*;
#[test]
fn test_export_to_stdout() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("version"))
.stdout(predicate::str::contains("[llm]"));
}
#[test]
fn test_export_to_file() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("exported.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration exported"));
assert!(export_path.exists(), "Export file should be created");
let content = fs::read_to_string(&export_path).unwrap();
assert!(content.contains("version"), "Export should contain version");
assert!(
content.contains("[llm]"),
"Export should contain LLM config"
);
}
#[test]
fn test_export_encrypted() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"test_password_123",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("encrypted and exported"));
assert!(export_path.exists(), "Export file should be created");
let content = fs::read_to_string(&export_path).unwrap();
assert!(
content.starts_with("ENCRYPTED:"),
"Encrypted file should start with ENCRYPTED:"
);
assert!(
!content.contains("[llm]"),
"Encrypted content should not be readable"
);
}
}
mod config_import {
use super::*;
#[test]
fn test_import_plain_config() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let import_path = temp_dir.path().join("import.toml");
let plain_config = r#"
version = "1"
[llm]
provider = "openai"
model = "gpt-4"
max_tokens = 1000
temperature = 0.7
timeout = 60
api_key_storage = "keyring"
[commit]
format = "conventional"
auto_generate = true
allow_empty = false
gpg_sign = false
max_subject_length = 100
require_scope = false
require_body = false
body_required_types = ["feat", "fix"]
[tag]
version_prefix = "v"
auto_generate = true
gpg_sign = false
include_changelog = true
[changelog]
path = "CHANGELOG.md"
auto_generate = true
format = "keep-a-changelog"
include_hashes = false
include_authors = false
group_by_type = true
[theme]
colors = true
icons = true
date_format = "%Y-%m-%d"
[language]
output_language = "en"
keep_types_english = true
keep_changelog_types_english = true
"#;
fs::write(&import_path, plain_config).unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path.to_str().unwrap(),
"--file",
import_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration imported"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("openai"));
}
#[test]
fn test_import_encrypted_config() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.provider",
"anthropic",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"secure_password",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
"secure_password",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration imported"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("anthropic"));
}
#[test]
fn test_import_encrypted_wrong_password() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"correct_password",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
"wrong_password",
]);
cmd.assert()
.failure()
.stderr(predicate::str::contains("Failed to decrypt"));
}
}
mod config_export_import_roundtrip {
use super::*;
#[test]
fn test_roundtrip_plain() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("export.toml");
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.model",
"gpt-4-turbo",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.model",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("gpt-4-turbo"));
}
#[test]
fn test_roundtrip_encrypted() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("encrypted.toml");
let password = "my_secure_password_123";
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.provider",
"deepseek",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.model",
"deepseek-chat",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
password,
]);
cmd.assert().success();
let exported_content = fs::read_to_string(&export_path).unwrap();
assert!(exported_content.starts_with("ENCRYPTED:"));
assert!(!exported_content.contains("deepseek"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
password,
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("deepseek"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.model",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("deepseek-chat"));
}
}

View File

@@ -1,4 +1,4 @@
use assert_cmd::Command;
use assert_cmd::cargo::cargo_bin_cmd;
use predicates::prelude::*;
use std::fs;
use std::path::PathBuf;
@@ -47,22 +47,43 @@ fn create_commit(dir: &PathBuf, message: &str) {
.expect("Failed to create commit");
}
fn setup_git_repo(dir: &PathBuf) {
create_git_repo(dir);
configure_git_user(dir);
}
fn setup_test_repo_with_file(dir: &PathBuf, file_name: &str, file_content: &str) {
setup_git_repo(dir);
create_test_file(dir, file_name, file_content);
stage_file(dir, file_name);
}
fn init_quicommit(dir: &PathBuf, config_path: &PathBuf) {
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(dir);
cmd.assert().success();
}
mod cli_basic {
use super::*;
#[test]
fn test_help() {
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.arg("--help");
cmd.assert()
.success()
.stdout(predicate::str::contains("QuiCommit"))
.stdout(predicate::str::contains("AI-powered Git assistant"));
.stdout(predicate::str::contains("AI-powered Git assistant"))
.stdout(predicate::str::contains("Usage:"))
.stdout(predicate::str::contains("Commands:"))
.stdout(predicate::str::contains("Options:"));
}
#[test]
fn test_version() {
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.arg("--version");
cmd.assert()
.success()
@@ -71,7 +92,7 @@ mod cli_basic {
#[test]
fn test_no_args_shows_help() {
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.assert()
.failure()
.stderr(predicate::str::contains("Usage:"));
@@ -85,9 +106,15 @@ mod cli_basic {
create_git_repo(&repo_path);
configure_git_user(&repo_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["-vv", "init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"-vv",
"init",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
}
@@ -101,7 +128,7 @@ mod init_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -114,7 +141,7 @@ mod init_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
@@ -131,7 +158,7 @@ mod init_command {
let config_path = repo_path.join("test_config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
@@ -143,12 +170,18 @@ mod init_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--reset", "--config", config_path.to_str().unwrap()]);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"init",
"--yes",
"--reset",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("initialized successfully"));
@@ -163,7 +196,7 @@ mod profile_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -176,11 +209,11 @@ mod profile_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -197,11 +230,11 @@ mod config_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -214,11 +247,11 @@ mod config_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -235,13 +268,19 @@ mod commit_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(temp_dir.path());
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(temp_dir.path());
cmd.assert()
.failure()
@@ -252,19 +291,23 @@ mod commit_command {
fn test_commit_no_changes() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
setup_git_repo(&repo_path);
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--manual", "-m", "test: empty", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"test: empty",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
@@ -275,22 +318,23 @@ mod commit_command {
fn test_commit_with_staged_changes() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "Hello, World!");
stage_file(&repo_path, "test.txt");
setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--manual", "-m", "test: add test file", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"test: add test file",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
@@ -301,27 +345,52 @@ mod commit_command {
fn test_commit_date_mode() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "daily.txt", "Daily update");
stage_file(&repo_path, "daily.txt");
setup_test_repo_with_file(&repo_path, "daily.txt", "Daily update");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--date", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--date",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
.stdout(predicate::str::contains("Dry run"));
}
#[test]
fn test_commit_with_think_flag() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--think",
"--manual",
"-m",
"test: think flag",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
}
}
mod tag_command {
@@ -332,13 +401,19 @@ mod tag_command {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["tag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(temp_dir.path());
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"tag",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(temp_dir.path());
cmd.assert()
.failure()
@@ -349,28 +424,60 @@ mod tag_command {
fn test_tag_list_empty() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
setup_git_repo(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: initial commit");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["tag", "--name", "v0.1.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"tag",
"--name",
"v0.1.0",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
.stdout(predicate::str::contains("v0.1.0"));
}
#[test]
fn test_tag_with_think_flag() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_git_repo(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: initial commit");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"tag",
"--think",
"--name",
"v0.2.0",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
}
}
mod changelog_command {
@@ -380,20 +487,23 @@ mod changelog_command {
fn test_changelog_init() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
setup_git_repo(&repo_path);
let config_path = repo_path.join("config.toml");
let changelog_path = repo_path.join("CHANGELOG.md");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["changelog", "--init", "--output", changelog_path.to_str().unwrap(), "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"changelog",
"--init",
"--output",
changelog_path.to_str().unwrap(),
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
@@ -404,26 +514,26 @@ mod changelog_command {
fn test_changelog_dry_run() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
setup_git_repo(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: add feature");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"changelog",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["changelog", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert()
.success();
}
}
@@ -435,7 +545,7 @@ mod cross_platform {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("subdir").join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
@@ -449,7 +559,7 @@ mod cross_platform {
fs::create_dir_all(&space_dir).unwrap();
let config_path = space_dir.join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
@@ -463,7 +573,7 @@ mod cross_platform {
fs::create_dir_all(&unicode_dir).unwrap();
let config_path = unicode_dir.join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
@@ -532,22 +642,23 @@ mod validators {
fn test_commit_message_validation() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
setup_test_repo_with_file(&repo_path, "test.txt", "content");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--manual", "-m", "invalid commit message without type", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"invalid commit message without type",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.failure()
@@ -558,22 +669,23 @@ mod validators {
fn test_valid_conventional_commit() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
setup_test_repo_with_file(&repo_path, "test.txt", "content");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["commit", "--manual", "-m", "feat: add new feature", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"feat: add new feature",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
@@ -588,22 +700,23 @@ mod subcommands {
fn test_commit_alias() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
setup_test_repo_with_file(&repo_path, "test.txt", "content");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["c", "--manual", "-m", "fix: test", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"c",
"--manual",
"-m",
"fix: test",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.success()
@@ -615,7 +728,7 @@ mod subcommands {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -628,11 +741,11 @@ mod subcommands {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert()
@@ -640,3 +753,79 @@ mod subcommands {
.stdout(predicate::str::contains("default"));
}
}
mod edge_cases {
use super::*;
#[test]
fn test_config_file_not_found() {
let temp_dir = TempDir::new().unwrap();
let non_existent_config = temp_dir.path().join("non_existent_config.toml");
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"show",
"--config",
non_existent_config.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("QuiCommit Configuration"))
.stdout(predicate::str::contains("Default profile: (none)"))
.stdout(predicate::str::contains("Profiles: 0"));
}
#[test]
fn test_invalid_git_repo() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
let config_path = repo_path.join("config.toml");
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.failure()
.stderr(predicate::str::contains("git").or(predicate::str::contains("repository")));
}
#[test]
fn test_empty_commit_message() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_test_repo_with_file(&repo_path, "test.txt", "content");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().failure().stderr(predicate::str::contains(
"Invalid conventional commit format",
));
}
}