23 Commits

Author SHA1 Message Date
459670f363 fix(llm): 修复 Kimi API temperature 参数配置 2026-06-01 17:48:42 +08:00
7636d0b5a6 feat(llm): 统一思考模式配置,支持显式禁用状态 2026-06-01 17:39:36 +08:00
928ebb61b4 refactor(llm): renumber system prompt rules 2026-05-27 15:37:51 +08:00
7e85cdd8b0 chore(release): 升级版本号至 0.3.0 2026-05-27 15:16:26 +08:00
90074e6e32 style: 格式化代码并优化导入顺序 2026-05-27 15:15:15 +08:00
b8182e7538 修复kimi返回信息的读取错误 2026-05-27 14:50:47 +08:00
4331b9306e LLM支持优化 2026-05-26 17:43:42 +08:00
a08bc809bb 修复bug 2026-05-26 16:30:28 +08:00
1063369d96 feat(deepseek): 添加 DeepSeek reasoning 模式支持 2026-05-26 16:27:49 +08:00
3a57d25a76 docs: 添加 QuiCommit 项目路线图文档 2026-05-14 17:04:13 +08:00
8152edba39 chore: 删除构建输出日志文件 2026-05-13 13:54:50 +08:00
679db5b1db chore: 清理大量未使用的变量、方法及结构体警告 2026-05-13 13:54:20 +08:00
b1ad68c7b5 build: 升级版本号至 0.2.0 2026-05-13 12:08:03 +08:00
280d6ec5c9 feat(generator): 按文件重要性对暂存差异排序 2026-05-13 12:07:01 +08:00
68427c4a11 chore: 发布 v0.1.11 并更新文档 2026-03-23 18:07:11 +08:00
8dd9e85b77 feat(config): 在加密导出/导入中包含个人访问令牌 2026-03-23 17:59:23 +08:00
0c7d2ad518 refactor: 移除未使用的代码和注释掉的辅助函数 2026-03-20 18:05:33 +08:00
0289dd4684 chore: 升级版本至 0.1.10 并更新密钥环与加密相关描述 2026-03-19 16:34:45 +08:00
e2d43315e3 feat: 新增系统密钥环安全存储API密钥与自动生成变更日志功能 2026-03-12 18:05:38 +08:00
0e1c2c6350 chore: 移除测试 keyring 功能的临时文件 2026-03-12 17:45:05 +08:00
da85fc94b1 feat(keyring): 集成系统密钥环安全存储 API key 2026-03-12 17:42:41 +08:00
c66d782eab chore: 更新版本号至 0.1.9 并补充 changelog 2026-03-06 16:31:46 +08:00
358b44ab81 fix(generator): 修复diff截断时的字符边界问题 2026-03-06 16:28:26 +08:00
41 changed files with 7352 additions and 2851 deletions

2
.gitignore vendored
View File

@@ -21,3 +21,5 @@ test_output/
# Config (for development) # Config (for development)
config.toml config.toml
.claude/
CLAUDE.md

View File

@@ -5,6 +5,55 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
暂无。
## [0.3.1] - 2026-06-01
### ✨ 新功能
- 按文件重要性对暂存差异排序,优先处理核心变更
- DeepSeek 新增 reasoning 推理模式支持
- LLM 统一思考模式配置,支持显式启用/禁用思考状态
- 新增 `thinking.rs` 思考状态管理模块
### 🐞 错误修复
- 修复 Kimi 返回信息的读取错误
- 修复 DeepSeek 和 Kimi 流式响应的解析问题
### 📚 文档
- 新增 ROADMAP.md 项目路线图文档
### 🔧 其他变更
- LLM 模块大规模重构所有提供商Anthropic、DeepSeek、Kimi、Ollama、OpenAI、OpenRouter适配流式响应处理
- 代码格式化并优化导入顺序
- 清理大量未使用的变量、方法及结构体警告
- 清理构建输出日志文件
- 重新编号 LLM 系统提示规则
- i18n 多语言消息格式修复
- 各命令模块commit、tag、changelog、config、profile、init持续优化
## [0.1.11] - 2026-03-23
### ✨ 新功能
- 新增配置导出导入功能,支持加密保护
- Profile 支持 Token 管理PAT 等)
- 自动生成和维护 Keep a Changelog 格式的变更日志
- 交互式命令行界面,支持预览和确认
### 🔐 安全特性
- 敏感数据加密存储API 密钥等)
- 使用系统密钥环安全保存凭证
### 🔧 其他变更
- 优化 diff 截断逻辑,使用字符边界确保多字节字符安全
- 改进配置管理器,支持修改追踪
## [0.1.9] - 2026-03-06
### 🐞 错误修复
- 修复diff截断时的字符边界问题
## [0.1.7] - 2026-02-14 ## [0.1.7] - 2026-02-14
### 🐞 错误修复 ### 🐞 错误修复

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "quicommit" name = "quicommit"
version = "0.1.8" version = "0.3.1"
edition = "2024" edition = "2024"
authors = ["Sidney Zhang <zly@lyzhang.me>"] authors = ["Sidney Zhang <zly@lyzhang.me>"]
description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)" description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)"
@@ -33,7 +33,7 @@ git2 = "0.20.3"
which = "6.0" which = "6.0"
# HTTP client for LLM APIs # HTTP client for LLM APIs
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
tokio = { version = "1.35", features = ["full"] } tokio = { version = "1.35", features = ["full"] }
# Error handling # Error handling
@@ -57,6 +57,7 @@ sha2 = "0.10"
hex = "0.4" hex = "0.4"
textwrap = "0.16" textwrap = "0.16"
async-trait = "0.1" async-trait = "0.1"
futures-util = "0.3"
serde_json = "1.0" serde_json = "1.0"
atty = "0.2" atty = "0.2"
@@ -66,6 +67,9 @@ argon2 = "0.5"
rand = "0.8" rand = "0.8"
base64 = "0.22" base64 = "0.22"
# System keyring for secure API key storage
keyring = { version = "3.6.3", features = ["apple-native", "windows-native", "sync-secret-service"] }
# Interactive editor # Interactive editor
edit = "0.1" edit = "0.1"
@@ -80,11 +84,13 @@ mockall = "0.12"
wiremock = "0.6" wiremock = "0.6"
[profile.release] [profile.release]
opt-level = 3 opt-level = "s"
lto = true lto = "thin"
codegen-units = 1 codegen-units = 1
panic = "abort"
strip = true strip = true
debug = false
[profile.dev] [profile.dev]
opt-level = 0 opt-level = 1
debug = true debug = true

128
RAODMAP.md Normal file
View File

@@ -0,0 +1,128 @@
# QuiCommit Roadmap
## 已完成 ✅
- [x] 基础 Git 操作commit、tag、changelog
- [x] AI 驱动提交信息生成Conventional Commits / commitlint 格式)
- [x] 多 LLM 提供商支持Ollama、OpenAI、Anthropic、Kimi、DeepSeek、OpenRouter
- [x] 多 Git Profile 管理SSH 密钥 + GPG 签名)
- [x] 语义化版本自动升级与 AI 发布说明
- [x] Keep a Changelog 格式自动生成
- [x] 系统密钥环安全存储 API Key
- [x] 敏感数据加密存储AES-GCM + Argon2
- [x] 交互式 CLI 预览与确认
- [x] 7 种语言国际化支持
- [x] 配置导出/导入(支持加密保护)
- [x] Profile Token 管理PAT 等)
---
## 进行中 🚧
暂无。
---
## 计划中 📋
### 1. Git 凭证管理器
将 Git 凭证管理集成到 QuiCommit 中,统一管理 HTTPS 仓库的身份认证。
- [ ] **Git Credential Helper 集成**
- 实现 `git credential-store` / `git-credential-libsecret` 等标准的 credential helper 协议
- 支持 `quicommit credential get|store|erase` 子命令
- 与系统密钥环无缝对接,复用已有的 `KeyringManager`
- [ ] **凭证管理 CLI**
- `quicommit credential list` — 列出所有已存储的凭证
- `quicommit credential add` — 手动添加凭证(用户名 + 密码/Token
- `quicommit credential remove` — 删除指定凭证
- `quicommit credential status` — 查看凭证管理状态
- [ ] **跨平台支持**
- Windows集成 Windows Credential Manager
- macOS集成 Keychain
- Linux通过 Secret Service / D-Bus 对接 GNOME Keyring / KWallet
- [ ] **安全增强**
- 支持 PATPersonal Access Token按 scope / 有效期管理
- 支持凭证过期检查和自动提醒
---
### 2. 新增模型支持
扩展 LLM 提供商和模型覆盖范围,满足更多场景和偏好。
- [x] **新增 DeepSeek 最新模型**
- 支持 `deepseek-chat`DeepSeek-V3
- 支持 `deepseek-reasoner`DeepSeek-R1
- 支持 `deepseek-v4`
- [ ] **新增国内模型提供商**
- 通义千问 (Qwen) — 阿里云 DashScope API
- 文心一言 (ERNIE) — 百度千帆 API
- 智谱 GLM — ChatGLM API
- 百川 (Baichuan) — Baichuan API
- [ ] **新增国际模型提供商**
- Google Gemini API
- Mistral AI API
- Cohere API
- Groq (LPU 推理加速)
- [ ] **本地模型扩展**
- 支持 llama.cpp 服务端(兼容 OpenAI API 格式)
- 支持 vLLM 部署的模型
- 本地模型推荐列表与一键配置向导
- [ ] **模型能力适配**
- 不同模型的 token 限制自适应
- 模型特定的 prompt 模板优化
- 支持 function calling / tool use用于复杂生成场景
---
### 3. 生成体验优化
提升 AI 生成提交信息、标签说明和变更日志时的用户体验。
- [x] **流式输出与实时反馈**
- 支持 SSEServer-Sent Events流式生成
- 终端打字机效果实时显示生成内容
- 流式生成过程中支持 `Ctrl+C` 中断
- [ ] **生成质量提升**
- 基于 commitlint 规则的后校验与自动修正
- 支持 Few-shot 示例引导(用户可自定义示例库)
- 生成结果的置信度评分与多候选方案
- [ ] **Diff 上下文增强**
- 智能 diff 摘要(大改动时自动压缩关键信息)
- 支持 `.gitattributes` 排除/包含规则
- 按文件类型分组生成更精准的提交描述
- [ ] **交互式编辑增强**
- 生成后支持内联编辑(类似 `git rebase -i` 体验)
- 支持重新生成指定部分(如 scope、description
- 历史提交信息学习与风格适配
- [ ] **批量操作支持**
- 批量生成多个 commit分组暂存区变更
- `--dry-run` 预览模式(只生成本地查看,不写 Git
- [ ] **性能优化**
- API 请求并发优化(多个模型同时生成候选)
- 本地缓存常用 prompt 模板
- 减少不必要的 diff 计算
---
## 长远规划 🌟
- [ ] **VS Code 扩展** — 在 IDE 内直接使用 QuiCommit
- [ ] **GitHub Action / GitLab CI 集成** — 自动化 PR 标题和描述生成
- [ ] **团队协作** — 共享 commit 风格配置、prompt 模板库
- [ ] **Web Dashboard** — 可视化管理多仓库的 Git 活动与 AI 生成统计
- [ ] **插件系统** — 允许社区贡献自定义 LLM 提供商和生成策略

View File

@@ -6,8 +6,12 @@ A powerful AI-powered Git assistant for generating conventional commits, tags, a
[Still in early development, some features may not be complete. Feedback and contributions are welcome.] [Still in early development, some features may not be complete. Feedback and contributions are welcome.]
> ⚠️ **Important Notice**: QuiCommit now uses system keyring to store API keys securely. This change may cause breaking changes to your existing configuration. If you encounter issues after updating, please run `quicommit config reset --force` to reset your configuration, then reconfigure your settings.
![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white) ![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white)
![License](https://img.shields.io/badge/license-MIT-blue.svg) ![License](https://img.shields.io/badge/license-MIT-blue.svg)
![Crates.io Version](https://img.shields.io/crates/v/quicommit)
## Features ## Features
@@ -16,8 +20,10 @@ A powerful AI-powered Git assistant for generating conventional commits, tags, a
- **Profile Management**: Manage multiple Git identities with SSH keys and GPG signing support - **Profile Management**: Manage multiple Git identities with SSH keys and GPG signing support
- **Smart Tagging**: Semantic version bumping with AI-generated release notes - **Smart Tagging**: Semantic version bumping with AI-generated release notes
- **Changelog Generation**: Automatic changelog generation in Keep a Changelog format - **Changelog Generation**: Automatic changelog generation in Keep a Changelog format
- **Security**: Encrypt sensitive data - **Security**: Use system keyring to store API keys securely
- **Interactive UI**: Beautiful CLI with previews and confirmations - **Interactive UI**: Beautiful CLI with previews and confirmations
- **Multi-language Support**: Output in 7 languages (English, Chinese, Japanese, Korean, Spanish, French, German)
- **Config Export/Import**: Backup and restore configuration with optional encryption
## Installation ## Installation
@@ -159,30 +165,30 @@ quicommit profile token
```bash ```bash
# Configure Ollama (local) # Configure Ollama (local)
quicommit config set-llm ollama quicommit config set-llm ollama
quicommit config set-ollama --url http://localhost:11434 --model llama2 quicommit config set-llm ollama --url http://localhost:11434 --model llama2
# Configure OpenAI # Configure OpenAI
quicommit config set-llm openai quicommit config set-llm openai
quicommit config set-openai-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
# Configure Anthropic Claude # Configure Anthropic Claude
quicommit config set-llm anthropic quicommit config set-llm anthropic
quicommit config set-anthropic-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
# Configure Kimi (Moonshot AI) # Configure Kimi (Moonshot AI)
quicommit config set-llm kimi quicommit config set-llm kimi
quicommit config set-kimi-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k quicommit config set-llm kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
# Configure DeepSeek # Configure DeepSeek
quicommit config set-llm deepseek quicommit config set-llm deepseek
quicommit config set-deepseek-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat quicommit config set-llm deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
# Configure OpenRouter # Configure OpenRouter
quicommit config set-llm openrouter quicommit config set-llm openrouter
quicommit config set-openrouter-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4 quicommit config set-llm openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
# Set commit format # Set commit format
quicommit config set-commit-format conventional quicommit config set-commit-format conventional
@@ -193,7 +199,7 @@ quicommit config set-version-prefix v
# Set changelog path # Set changelog path
quicommit config set-changelog-path CHANGELOG.md quicommit config set-changelog-path CHANGELOG.md
# Set output language # Set output language (en, zh, ja, ko, es, fr, de)
quicommit config set-language en quicommit config set-language en
# Set keep commit types in English # Set keep commit types in English
@@ -205,8 +211,22 @@ quicommit config set-keep-changelog-types-english true
# Test LLM connection # Test LLM connection
quicommit config test-llm quicommit config test-llm
# Check keyring availability
quicommit config check-keyring
# Show config file path
quicommit config path
# Export configuration (with optional encryption)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# Import configuration
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# Reset configuration to defaults # Reset configuration to defaults
quicommit config reset quicommit config reset --force
``` ```
## Command Reference ## Command Reference
@@ -396,17 +416,31 @@ quicommit config set llm.provider ollama
# Get configuration value # Get configuration value
quicommit config get llm.provider quicommit config get llm.provider
# Set API key (stored in system keyring)
quicommit config set-api-key YOUR_API_KEY
# Delete API key from keyring
quicommit config delete-api-key
# Test LLM connection # Test LLM connection
quicommit config test-llm quicommit config test-llm
# List available models # List available models
quicommit config list-models quicommit config list-models
# Export configuration # Check keyring availability
quicommit config check-keyring
# Show config file path
quicommit config path
# Export configuration (with optional encryption)
quicommit config export -o config-backup.toml quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# Import configuration # Import configuration
quicommit config import -i config-backup.toml quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# Reset configuration # Reset configuration
quicommit config reset --force quicommit config reset --force

View File

@@ -1,12 +1,13 @@
use std::env; use std::env;
fn main() { fn main() {
// Only generate completions when explicitly requested // Only generate completions when explicitly requested
if env::var("GENERATE_COMPLETIONS").is_ok() { if env::var("GENERATE_COMPLETIONS").is_ok() {
println!("cargo:warning=To generate shell completions, run: cargo run --bin quicommit -- completions"); println!(
"cargo:warning=To generate shell completions, run: cargo run --bin quicommit -- completions"
);
} }
// Rerun if build.rs changes // Rerun if build.rs changes
println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=build.rs");
} }

View File

@@ -4,6 +4,13 @@
# - macOS: ~/Library/Application Support/quicommit/config.toml # - macOS: ~/Library/Application Support/quicommit/config.toml
# - Windows: %APPDATA%\quicommit\config.toml # - Windows: %APPDATA%\quicommit\config.toml
# ⚠️ IMPORTANT: Keyring Feature Update
# QuiCommit now uses system keyring to store API keys securely.
# This change may cause breaking changes to your existing configuration.
# If you encounter issues after updating, please reset your configuration:
# quicommit config reset --force
# Then reconfigure your settings using the CLI commands.
# Configuration version (for migration) # Configuration version (for migration)
version = "1" version = "1"

View File

@@ -6,8 +6,11 @@
【目前还处在早期开发阶段,依然有一些功能未完善,欢迎反馈和贡献。】 【目前还处在早期开发阶段,依然有一些功能未完善,欢迎反馈和贡献。】
> ⚠️ **重要提示**QuiCommit 现在使用系统密钥环keyring来安全存储 API 密钥。此更改可能会对现有配置造成破坏性变更。如果在更新后遇到问题,请运行 `quicommit config reset --force` 重置配置,然后重新配置您的设置。
![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white) ![Rust](https://img.shields.io/badge/rust-%23000000.svg?style=for-the-badge&logo=rust&logoColor=white)
![License](https://img.shields.io/badge/license-MIT-blue.svg) ![License](https://img.shields.io/badge/license-MIT-blue.svg)
![Crates.io Version](https://img.shields.io/crates/v/quicommit)
## 主要功能 ## 主要功能
@@ -16,8 +19,10 @@
- **多配置管理**为不同场景管理多个Git身份支持SSH密钥和GPG签名配置 - **多配置管理**为不同场景管理多个Git身份支持SSH密钥和GPG签名配置
- **智能标签管理**基于语义版本自动检测升级AI生成标签信息 - **智能标签管理**基于语义版本自动检测升级AI生成标签信息
- **变更日志生成**自动生成Keep a Changelog格式的变更日志 - **变更日志生成**自动生成Keep a Changelog格式的变更日志
- **安全保护**加密存储敏感数据 - **安全保护**使用系统密钥环进行安全存储
- **交互式界面**美观的CLI界面支持预览和确认 - **交互式界面**美观的CLI界面支持预览和确认
- **多语言支持**支持7种语言输出中文、英语、日语、韩语、西班牙语、法语、德语
- **配置导出导入**:备份和恢复配置,支持加密保护
## 安装 ## 安装
@@ -159,30 +164,30 @@ quicommit profile token
```bash ```bash
# 配置Ollama本地 # 配置Ollama本地
quicommit config set-llm ollama quicommit config set-llm ollama
quicommit config set-ollama --url http://localhost:11434 --model llama2 quicommit config set-llm ollama --url http://localhost:11434 --model llama2
# 配置OpenAI # 配置OpenAI
quicommit config set-llm openai quicommit config set-llm openai
quicommit config set-openai-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
# 配置Anthropic Claude # 配置Anthropic Claude
quicommit config set-llm anthropic quicommit config set-llm anthropic
quicommit config set-anthropic-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
# 配置Kimi # 配置Kimi
quicommit config set-llm kimi quicommit config set-llm kimi
quicommit config set-kimi-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k quicommit config set-llm kimi --base-url https://api.moonshot.cn/v1 --model moonshot-v1-8k
# 配置DeepSeek # 配置DeepSeek
quicommit config set-llm deepseek quicommit config set-llm deepseek
quicommit config set-deepseek-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat quicommit config set-llm deepseek --base-url https://api.deepseek.com/v1 --model deepseek-chat
# 配置OpenRouter # 配置OpenRouter
quicommit config set-llm openrouter quicommit config set-llm openrouter
quicommit config set-openrouter-key YOUR_API_KEY quicommit config set-api-key YOUR_API_KEY
quicommit config set-openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4 quicommit config set-llm openrouter --base-url https://openrouter.ai/api/v1 --model openai/gpt-4
# 设置提交格式 # 设置提交格式
quicommit config set-commit-format conventional quicommit config set-commit-format conventional
@@ -193,7 +198,7 @@ quicommit config set-version-prefix v
# 设置变更日志路径 # 设置变更日志路径
quicommit config set-changelog-path CHANGELOG.md quicommit config set-changelog-path CHANGELOG.md
# 设置输出语言 # 设置输出语言zh, en, ja, ko, es, fr, de
quicommit config set-language zh quicommit config set-language zh
# 设置保持提交类型为英文 # 设置保持提交类型为英文
@@ -205,8 +210,22 @@ quicommit config set-keep-changelog-types-english true
# 测试LLM连接 # 测试LLM连接
quicommit config test-llm quicommit config test-llm
# 检查密钥环可用性
quicommit config check-keyring
# 显示配置文件路径
quicommit config path
# 导出配置(支持加密)
quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# 导入配置
quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# 重置配置为默认值 # 重置配置为默认值
quicommit config reset quicommit config reset --force
``` ```
## 命令参考 ## 命令参考
@@ -396,17 +415,31 @@ quicommit config set llm.provider ollama
# 获取配置值 # 获取配置值
quicommit config get llm.provider quicommit config get llm.provider
# 设置API密钥存储在系统密钥环中
quicommit config set-api-key YOUR_API_KEY
# 从密钥环删除API密钥
quicommit config delete-api-key
# 测试LLM连接 # 测试LLM连接
quicommit config test-llm quicommit config test-llm
# 列出可用模型 # 列出可用模型
quicommit config list-models quicommit config list-models
# 导出配置 # 检查密钥环可用性
quicommit config check-keyring
# 显示配置文件路径
quicommit config path
# 导出配置(支持加密)
quicommit config export -o config-backup.toml quicommit config export -o config-backup.toml
quicommit config export -o config-backup.enc --password
# 导入配置 # 导入配置
quicommit config import -i config-backup.toml quicommit config import -i config-backup.toml
quicommit config import -i config-backup.enc --password
# 重置配置 # 重置配置
quicommit config reset --force quicommit config reset --force

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use chrono::Utc; use chrono::Utc;
use clap::Parser; use clap::Parser;
use colored::Colorize; use colored::Colorize;
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager}; use crate::config::{Language, manager::ConfigManager};
use crate::generator::ContentGenerator; use crate::generator::ContentGenerator;
use crate::git::find_repo; use crate::git::find_repo;
use crate::git::{changelog::*, CommitInfo}; use crate::git::{CommitInfo, changelog::*};
use crate::i18n::{Messages, translate_changelog_category}; use crate::i18n::{Messages, translate_changelog_category};
/// Generate changelog /// Generate changelog
@@ -55,6 +55,10 @@ pub struct ChangelogCommand {
#[arg(long)] #[arg(long)]
dry_run: bool, dry_run: bool,
/// Enable thinking mode for this changelog (override config)
#[arg(long)]
think: bool,
/// Skip interactive prompts /// Skip interactive prompts
#[arg(short = 'y', long)] #[arg(short = 'y', long)]
yes: bool, yes: bool,
@@ -74,18 +78,20 @@ impl ChangelogCommand {
// Initialize changelog if requested // Initialize changelog if requested
if self.init { if self.init {
let path = self.output.as_ref() let path = self
.map(|p| p.clone()) .output
.clone()
.unwrap_or_else(|| PathBuf::from(&config.changelog.path)); .unwrap_or_else(|| PathBuf::from(&config.changelog.path));
init_changelog(&path)?; init_changelog(&path)?;
println!("{}", messages.initialized_changelog(&format!("{:?}", path))); println!("{}", messages.initialized_changelog(&format!("{:?}", path)));
return Ok(()); return Ok(());
} }
// Determine output path // Determine output path
let output_path = self.output.as_ref() let output_path = self
.map(|p| p.clone()) .output
.clone()
.unwrap_or_else(|| PathBuf::from(&config.changelog.path)); .unwrap_or_else(|| PathBuf::from(&config.changelog.path));
// Determine format // Determine format
@@ -94,7 +100,10 @@ impl ChangelogCommand {
Some("keep") | Some("keep-a-changelog") => ChangelogFormat::KeepAChangelog, Some("keep") | Some("keep-a-changelog") => ChangelogFormat::KeepAChangelog,
Some("custom") => ChangelogFormat::Custom, Some("custom") => ChangelogFormat::Custom,
None => ChangelogFormat::KeepAChangelog, None => ChangelogFormat::KeepAChangelog,
Some(f) => bail!("Unknown format: {}. Use: keep-a-changelog, github-releases", f), Some(f) => bail!(
"Unknown format: {}. Use: keep-a-changelog, github-releases",
f
),
}; };
// Get version // Get version
@@ -112,11 +121,11 @@ impl ChangelogCommand {
// Get commits // Get commits
println!("{}", messages.fetching_commits()); println!("{}", messages.fetching_commits());
let commits = generate_from_history(&repo, self.from.as_deref(), Some(&self.to))?; let commits = generate_from_history(&repo, self.from.as_deref(), Some(&self.to))?;
if commits.is_empty() { if commits.is_empty() {
bail!("{}", messages.no_commits_found()); bail!("{}", messages.no_commits_found());
} }
println!("{}", messages.found_commits(commits.len())); println!("{}", messages.found_commits(commits.len()));
// Generate changelog // Generate changelog
@@ -148,7 +157,7 @@ impl ChangelogCommand {
println!("{}", "".repeat(60)); println!("{}", "".repeat(60));
let confirm = Confirm::new() let confirm = Confirm::new()
.with_prompt(&messages.write_to_file(&format!("{:?}", output_path))) .with_prompt(messages.write_to_file(&format!("{:?}", output_path)))
.default(true) .default(true)
.interact()?; .interact()?;
@@ -168,7 +177,7 @@ impl ChangelogCommand {
} else if existing.starts_with("# Changelog") { } else if existing.starts_with("# Changelog") {
let lines: Vec<&str> = existing.lines().collect(); let lines: Vec<&str> = existing.lines().collect();
let mut header_end = 0; let mut header_end = 0;
for (i, line) in lines.iter().enumerate() { for (i, line) in lines.iter().enumerate() {
if i == 0 && line.starts_with('#') { if i == 0 && line.starts_with('#') {
header_end = i + 1; header_end = i + 1;
@@ -178,10 +187,10 @@ impl ChangelogCommand {
break; break;
} }
} }
let header = lines[..header_end].join("\n"); let header = lines[..header_end].join("\n");
let rest = lines[header_end..].join("\n"); let rest = lines[header_end..].join("\n");
format!("{}\n{}\n{}", header, changelog, rest) format!("{}\n{}\n{}", header, changelog, rest)
} else { } else {
format!("{}{}", CHANGELOG_HEADER, changelog) format!("{}{}", CHANGELOG_HEADER, changelog)
@@ -204,13 +213,14 @@ impl ChangelogCommand {
messages: &Messages, messages: &Messages,
) -> Result<String> { ) -> Result<String> {
let manager = ConfigManager::new()?; let manager = ConfigManager::new()?;
let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English); let language = manager.get_language().unwrap_or(Language::English);
println!("{}", messages.ai_generating_changelog()); println!("{}", messages.ai_generating_changelog());
let generator = ContentGenerator::new(&config.llm).await?; let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
generator.generate_changelog_entry(version, commits, language).await generator
.generate_changelog_entry(version, commits, language)
.await
} }
fn generate_with_template( fn generate_with_template(
@@ -221,14 +231,14 @@ impl ChangelogCommand {
language: Language, language: Language,
) -> Result<String> { ) -> Result<String> {
let manager = ConfigManager::new()?; let manager = ConfigManager::new()?;
let generator = ChangelogGenerator::new() let generator = ChangelogGenerator::new()
.format(format) .format(format)
.include_hashes(self.include_hashes) .include_hashes(self.include_hashes)
.include_authors(self.include_authors); .include_authors(self.include_authors);
let changelog = generator.generate(version, Utc::now(), commits)?; let changelog = generator.generate(version, Utc::now(), commits)?;
// Translate changelog categories if configured // Translate changelog categories if configured
if !manager.keep_changelog_types_english() { if !manager.keep_changelog_types_english() {
Ok(self.translate_changelog_categories(&changelog, language)) Ok(self.translate_changelog_categories(&changelog, language))
@@ -236,14 +246,15 @@ impl ChangelogCommand {
Ok(changelog) Ok(changelog)
} }
} }
fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String { fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String {
let translated = changelog changelog
.lines() .lines()
.map(|line| { .map(|line| {
if line.starts_with("## ") || line.starts_with("### ") { if line.starts_with("## ") || line.starts_with("### ") {
let category = line.trim_start_matches("## ").trim_start_matches("### "); let category = line.trim_start_matches("## ").trim_start_matches("### ");
let translated_category = translate_changelog_category(category, language, false); let translated_category =
translate_changelog_category(category, language, false);
if line.starts_with("## ") { if line.starts_with("## ") {
format!("## {}", translated_category) format!("## {}", translated_category)
} else { } else {
@@ -254,7 +265,6 @@ impl ChangelogCommand {
} }
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n")
translated
} }
} }

View File

@@ -1,14 +1,14 @@
use anyhow::{bail, Context, Result}; use anyhow::{Context, Result, bail};
use clap::Parser; use clap::Parser;
use colored::Colorize; use colored::Colorize;
use dialoguer::{Confirm, Input, Select}; use dialoguer::{Confirm, Input, Select};
use std::path::PathBuf; use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager};
use crate::config::CommitFormat; use crate::config::CommitFormat;
use crate::config::{Language, manager::ConfigManager};
use crate::generator::ContentGenerator; use crate::generator::ContentGenerator;
use crate::git::{find_repo, GitRepo};
use crate::git::commit::{CommitBuilder, create_date_commit_message}; use crate::git::commit::{CommitBuilder, create_date_commit_message};
use crate::git::{GitRepo, find_repo};
use crate::i18n::Messages; use crate::i18n::Messages;
use crate::utils::validators::get_commit_types; use crate::utils::validators::get_commit_types;
@@ -71,6 +71,10 @@ pub struct CommitCommand {
#[arg(long)] #[arg(long)]
no_verify: bool, no_verify: bool,
/// Enable thinking mode for this commit (override config)
#[arg(short = 't', long)]
think: bool,
/// Skip interactive prompts /// Skip interactive prompts
#[arg(short = 'y', long)] #[arg(short = 'y', long)]
yes: bool, yes: bool,
@@ -88,7 +92,7 @@ impl CommitCommand {
pub async fn execute(&self, config_path: Option<PathBuf>) -> Result<()> { pub async fn execute(&self, config_path: Option<PathBuf>) -> Result<()> {
// Find git repository // Find git repository
let repo = find_repo(std::env::current_dir()?.as_path())?; let repo = find_repo(std::env::current_dir()?.as_path())?;
// Load configuration // Load configuration
let manager = if let Some(ref path) = config_path { let manager = if let Some(ref path) = config_path {
ConfigManager::with_path(path)? ConfigManager::with_path(path)?
@@ -98,7 +102,7 @@ impl CommitCommand {
let config = manager.config(); let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English); let language = manager.get_language().unwrap_or(Language::English);
let messages = Messages::new(language); let messages = Messages::new(language);
// Check for changes // Check for changes
let status = repo.status_summary()?; let status = repo.status_summary()?;
if status.clean && !self.amend { if status.clean && !self.amend {
@@ -119,7 +123,7 @@ impl CommitCommand {
println!("{}", messages.auto_stage_changes().yellow()); println!("{}", messages.auto_stage_changes().yellow());
repo.stage_all()?; repo.stage_all()?;
println!("{}", messages.staged_all().green()); println!("{}", messages.staged_all().green());
// Re-check status after staging to ensure changes are detected // Re-check status after staging to ensure changes are detected
let new_status = repo.status_summary()?; let new_status = repo.status_summary()?;
if new_status.staged == 0 { if new_status.staged == 0 {
@@ -179,14 +183,22 @@ impl CommitCommand {
let result = if self.amend { let result = if self.amend {
if self.dry_run { if self.dry_run {
println!("\n{} {}", messages.dry_run(), "- commit not amended.".yellow()); println!(
"\n{} {}",
messages.dry_run(),
"- commit not amended.".yellow()
);
return Ok(()); return Ok(());
} }
self.amend_commit(&repo, &commit_message)?; self.amend_commit(&repo, &commit_message)?;
None None
} else { } else {
if self.dry_run { if self.dry_run {
println!("\n{} {}", messages.dry_run(), "- commit not created.".yellow()); println!(
"\n{} {}",
messages.dry_run(),
"- commit not created.".yellow()
);
return Ok(()); return Ok(());
} }
CommitBuilder::new() CommitBuilder::new()
@@ -196,9 +208,13 @@ impl CommitCommand {
}; };
if let Some(commit_oid) = result { if let Some(commit_oid) = result {
println!("{} {}", messages.commit_created().green().bold(), commit_oid.to_string()[..8].to_string().cyan()); println!(
"{} {}",
messages.commit_created().green().bold(),
commit_oid.to_string()[..8].to_string().cyan()
);
} else { } else {
println!("{} {}", messages.commit_amended().green().bold(), "successfully"); println!("{} successfully", messages.commit_amended().green().bold());
} }
// Push after commit if requested or ask user // Push after commit if requested or ask user
@@ -228,8 +244,9 @@ impl CommitCommand {
} }
fn create_manual_commit(&self, format: CommitFormat) -> Result<String> { fn create_manual_commit(&self, format: CommitFormat) -> Result<String> {
let description = self.message.clone() let description = self.message.clone().ok_or_else(|| {
.ok_or_else(|| anyhow::anyhow!("Description required for manual commit. Use -m <message>"))?; anyhow::anyhow!("Description required for manual commit. Use -m <message>")
})?;
// Try to extract commit type from message if not provided // Try to extract commit type from message if not provided
let commit_type = if let Some(ref ct) = self.commit_type { let commit_type = if let Some(ref ct) = self.commit_type {
@@ -255,31 +272,40 @@ impl CommitCommand {
builder.build_message() builder.build_message()
} }
async fn generate_commit(&self, repo: &GitRepo, format: CommitFormat, messages: &Messages) -> Result<String> { async fn generate_commit(
&self,
repo: &GitRepo,
format: CommitFormat,
messages: &Messages,
) -> Result<String> {
let manager = ConfigManager::new()?; let manager = ConfigManager::new()?;
let config = manager.config();
// Check if LLM is configured let generator = ContentGenerator::new_with_think(&manager, self.think)
let generator = ContentGenerator::new(&config.llm).await .await
.context("Failed to initialize LLM. Use --manual for manual commit.")?; .context("Failed to initialize LLM. Use --manual for manual commit.")?;
println!("{}", messages.ai_analyzing()); println!("{}", messages.ai_analyzing());
let language_str = &config.language.output_language; let language = manager.get_language().unwrap_or(Language::English);
let language = Language::from_str(language_str).unwrap_or(Language::English);
let generated = if self.yes { let generated = if self.yes {
// Non-interactive mode: generate directly generator
generator.generate_commit_from_repo(repo, format, language).await? .generate_commit_from_repo(repo, format, language)
.await?
} else { } else {
// Interactive mode: allow user to review and regenerate generator
generator.generate_commit_interactive(repo, format, language).await? .generate_commit_interactive(repo, format, language)
.await?
}; };
Ok(generated.to_conventional()) Ok(generated.to_conventional())
} }
async fn create_interactive_commit(&self, format: CommitFormat, messages: &Messages) -> Result<String> { async fn create_interactive_commit(
&self,
format: CommitFormat,
messages: &Messages,
) -> Result<String> {
let types = get_commit_types(format == CommitFormat::Commitlint); let types = get_commit_types(format == CommitFormat::Commitlint);
// Select type // Select type
@@ -357,20 +383,21 @@ impl CommitCommand {
if !output.status.success() { if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout); let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() { let error_msg = if stderr.is_empty() {
if stdout.is_empty() { if stdout.is_empty() {
"GPG signing failed. Please check:\n\ "GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\ 1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\ 2. GPG agent is running\n\
3. You can sign commits manually (try: git commit --amend -S)".to_string() 3. You can sign commits manually (try: git commit --amend -S)"
.to_string()
} else { } else {
stdout.to_string() stdout.to_string()
} }
} else { } else {
stderr.to_string() stderr.to_string()
}; };
bail!("Failed to amend commit: {}", error_msg); bail!("Failed to amend commit: {}", error_msg);
} }
@@ -378,26 +405,26 @@ impl CommitCommand {
} }
} }
// Helper trait for optional builder methods // // Helper trait for optional builder methods
trait CommitBuilderExt { // trait CommitBuilderExt {
fn scope_opt(self, scope: Option<String>) -> Self; // fn scope_opt(self, scope: Option<String>) -> Self;
fn body_opt(self, body: Option<String>) -> Self; // fn body_opt(self, body: Option<String>) -> Self;
} // }
impl CommitBuilderExt for CommitBuilder { // impl CommitBuilderExt for CommitBuilder {
fn scope_opt(self, scope: Option<String>) -> Self { // fn scope_opt(self, scope: Option<String>) -> Self {
if let Some(s) = scope { // if let Some(s) = scope {
self.scope(s) // self.scope(s)
} else { // } else {
self // self
} // }
} // }
fn body_opt(self, body: Option<String>) -> Self { // fn body_opt(self, body: Option<String>) -> Self {
if let Some(b) = body { // if let Some(b) = body {
self.body(b) // self.body(b)
} else { // } else {
self // self
} // }
} // }
} // }

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,11 @@ use colored::Colorize;
use dialoguer::{Confirm, Input, Select}; use dialoguer::{Confirm, Input, Select};
use std::path::PathBuf; use std::path::PathBuf;
use crate::config::{GitProfile, Language};
use crate::config::manager::ConfigManager; use crate::config::manager::ConfigManager;
use crate::config::profile::{GpgConfig, SshConfig}; use crate::config::profile::{GpgConfig, SshConfig};
use crate::config::{GitProfile, Language};
use crate::i18n::Messages; use crate::i18n::Messages;
use crate::utils::keyring::{get_default_model, get_supported_providers, provider_needs_api_key};
use crate::utils::validators::validate_email; use crate::utils::validators::validate_email;
/// Initialize quicommit configuration /// Initialize quicommit configuration
@@ -27,35 +28,34 @@ impl InitCommand {
let messages = Messages::new(Language::English); let messages = Messages::new(Language::English);
println!("{}", messages.initializing().bold().cyan()); println!("{}", messages.initializing().bold().cyan());
let config_path = config_path.unwrap_or_else(|| { let config_path =
crate::config::AppConfig::default_path().unwrap() config_path.unwrap_or_else(|| crate::config::AppConfig::default_path().unwrap());
});
// Check if config already exists
if config_path.exists() && !self.reset { if config_path.exists() && !self.reset {
if !self.yes { if !self.yes {
let overwrite = Confirm::new() let overwrite = Confirm::new()
.with_prompt("Configuration already exists. Overwrite?") .with_prompt("Configuration already exists. Overwrite?")
.default(false) .default(false)
.interact()?; .interact()?;
if !overwrite { if !overwrite {
println!("{}", "Initialization cancelled.".yellow()); println!("{}", "Initialization cancelled.".yellow());
return Ok(()); return Ok(());
} }
} else { } else {
println!("{}", "Configuration already exists. Use --reset to overwrite.".yellow()); println!(
"{}",
"Configuration already exists. Use --reset to overwrite.".yellow()
);
return Ok(()); return Ok(());
} }
} }
// Create parent directory if needed
if let Some(parent) = config_path.parent() { if let Some(parent) = config_path.parent() {
std::fs::create_dir_all(parent) std::fs::create_dir_all(parent)
.map_err(|e| anyhow::anyhow!("Failed to create config directory: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to create config directory: {}", e))?;
} }
// Create new config manager with fresh config
let mut manager = ConfigManager::with_path_fresh(&config_path)?; let mut manager = ConfigManager::with_path_fresh(&config_path)?;
if self.yes { if self.yes {
@@ -65,11 +65,10 @@ impl InitCommand {
} }
manager.save()?; manager.save()?;
// Get configured language for final messages
let language = manager.get_language().unwrap_or(Language::English); let language = manager.get_language().unwrap_or(Language::English);
let messages = Messages::new(language); let messages = Messages::new(language);
println!("{}", messages.init_success().bold().green()); println!("{}", messages.init_success().bold().green());
println!("\n{}: {}", messages.config_file(), config_path.display()); println!("\n{}: {}", messages.config_file(), config_path.display());
println!("\n{}:", messages.next_steps()); println!("\n{}:", messages.next_steps());
@@ -81,22 +80,20 @@ impl InitCommand {
} }
async fn quick_setup(&self, manager: &mut ConfigManager) -> Result<()> { async fn quick_setup(&self, manager: &mut ConfigManager) -> Result<()> {
// Try to get git user info
let git_config = git2::Config::open_default()?; let git_config = git2::Config::open_default()?;
let user_name = git_config.get_string("user.name").unwrap_or_else(|_| "User".to_string());
let user_email = git_config.get_string("user.email").unwrap_or_else(|_| "user@example.com".to_string());
let profile = GitProfile::new( let user_name = git_config
"default".to_string(), .get_string("user.name")
user_name, .unwrap_or_else(|_| "User".to_string());
user_email, let user_email = git_config
); .get_string("user.email")
.unwrap_or_else(|_| "user@example.com".to_string());
let profile = GitProfile::new("default".to_string(), user_name, user_email);
manager.add_profile("default".to_string(), profile)?; manager.add_profile("default".to_string(), profile)?;
manager.set_default_profile(Some("default".to_string()))?; manager.set_default_profile(Some("default".to_string()))?;
// Set default LLM to Ollama
manager.set_llm_provider("ollama".to_string()); manager.set_llm_provider("ollama".to_string());
Ok(()) Ok(())
@@ -106,9 +103,8 @@ impl InitCommand {
let messages = Messages::new(Language::English); let messages = Messages::new(Language::English);
println!("\n{}", messages.setup_profile().bold()); println!("\n{}", messages.setup_profile().bold());
// Language selection
println!("\n{}", messages.select_output_language().bold()); println!("\n{}", messages.select_output_language().bold());
let languages = vec![ let languages = [
Language::English, Language::English,
Language::Chinese, Language::Chinese,
Language::Japanese, Language::Japanese,
@@ -117,32 +113,31 @@ impl InitCommand {
Language::French, Language::French,
Language::German, Language::German,
]; ];
let language_names: Vec<String> = languages.iter().map(|l| l.display_name().to_string()).collect(); let language_names: Vec<String> = languages
let language_idx = Select::new() .iter()
.items(&language_names) .map(|l| l.display_name().to_string())
.default(0) .collect();
.interact()?; let language_idx = Select::new().items(&language_names).default(0).interact()?;
let selected_language = languages[language_idx]; let selected_language = languages[language_idx];
manager.set_output_language(selected_language.to_code().to_string()); manager.set_output_language(selected_language.to_code().to_string());
// Update messages to selected language
let messages = Messages::new(selected_language); let messages = Messages::new(selected_language);
// Profile name
let profile_name: String = Input::new() let profile_name: String = Input::new()
.with_prompt(messages.profile_name()) .with_prompt(messages.profile_name())
.default("personal".to_string()) .default("personal".to_string())
.interact_text()?; .interact_text()?;
// User info
let git_config = git2::Config::open_default().ok(); let git_config = git2::Config::open_default().ok();
let default_name = git_config.as_ref() let default_name = git_config
.as_ref()
.and_then(|c| c.get_string("user.name").ok()) .and_then(|c| c.get_string("user.name").ok())
.unwrap_or_default(); .unwrap_or_default();
let default_email = git_config.as_ref() let default_email = git_config
.as_ref()
.and_then(|c| c.get_string("user.email").ok()) .and_then(|c| c.get_string("user.email").ok())
.unwrap_or_default(); .unwrap_or_default();
@@ -154,9 +149,7 @@ impl InitCommand {
let user_email: String = Input::new() let user_email: String = Input::new()
.with_prompt(messages.git_user_email()) .with_prompt(messages.git_user_email())
.default(default_email) .default(default_email)
.validate_with(|input: &String| { .validate_with(|input: &String| validate_email(input).map_err(|e| e.to_string()))
validate_email(input).map_err(|e| e.to_string())
})
.interact_text()?; .interact_text()?;
let description: String = Input::new() let description: String = Input::new()
@@ -170,14 +163,15 @@ impl InitCommand {
.interact()?; .interact()?;
let organization = if is_work { let organization = if is_work {
Some(Input::new() Some(
.with_prompt(messages.organization_name()) Input::new()
.interact_text()?) .with_prompt(messages.organization_name())
.interact_text()?,
)
} else { } else {
None None
}; };
// SSH configuration
let setup_ssh = Confirm::new() let setup_ssh = Confirm::new()
.with_prompt(messages.configure_ssh()) .with_prompt(messages.configure_ssh())
.default(false) .default(false)
@@ -189,7 +183,6 @@ impl InitCommand {
None None
}; };
// GPG configuration
let setup_gpg = Confirm::new() let setup_gpg = Confirm::new()
.with_prompt(messages.configure_gpg()) .with_prompt(messages.configure_gpg())
.default(false) .default(false)
@@ -201,12 +194,7 @@ impl InitCommand {
None None
}; };
// Create profile let mut profile = GitProfile::new(profile_name.clone(), user_name, user_email);
let mut profile = GitProfile::new(
profile_name.clone(),
user_name,
user_email,
);
if !description.is_empty() { if !description.is_empty() {
profile.description = Some(description); profile.description = Some(description);
@@ -220,59 +208,112 @@ impl InitCommand {
manager.add_profile(profile_name.clone(), profile)?; manager.add_profile(profile_name.clone(), profile)?;
manager.set_default_profile(Some(profile_name))?; manager.set_default_profile(Some(profile_name))?;
// LLM provider selection
println!("\n{}", messages.select_llm_provider().bold()); println!("\n{}", messages.select_llm_provider().bold());
let providers = vec![
let provider_display_names = vec![
"Ollama (local)", "Ollama (local)",
"OpenAI", "OpenAI",
"Anthropic Claude", "Anthropic Claude",
"Kimi (Moonshot AI)", "Kimi (Moonshot AI)",
"DeepSeek", "DeepSeek",
"OpenRouter" "OpenRouter",
]; ];
let provider_idx = Select::new() let provider_idx = Select::new()
.items(&providers) .items(&provider_display_names)
.default(0) .default(0)
.interact()?; .interact()?;
let provider = match provider_idx { let providers = get_supported_providers();
0 => "ollama", let provider = providers[provider_idx].to_string();
1 => "openai",
2 => "anthropic", let keyring = manager.keyring();
3 => "kimi", let keyring_available = keyring.is_available();
4 => "deepseek",
5 => "openrouter", if !keyring_available {
_ => "ollama", println!(
"\n{}",
"⚠ Keyring is not available on this system.".yellow()
);
println!("{}", keyring.get_status_message().yellow());
}
let api_key = if provider_needs_api_key(&provider) {
let env_key = std::env::var("QUICOMMIT_API_KEY")
.or_else(|_| {
std::env::var(format!("QUICOMMIT_{}_API_KEY", provider.to_uppercase()))
})
.ok();
if let Some(_key) = env_key {
println!(
"\n{} {}",
"".green(),
"Found API key in environment variable.".green()
);
None
} else if keyring_available {
let prompt = match provider.as_str() {
"openai" => messages.openai_api_key(),
"anthropic" => messages.anthropic_api_key(),
"kimi" => messages.kimi_api_key(),
"deepseek" => messages.deepseek_api_key(),
"openrouter" => messages.openrouter_api_key(),
_ => "API Key",
};
let key: String = Input::new().with_prompt(prompt).interact_text()?;
Some(key)
} else {
println!(
"\n{}",
"Please set the QUICOMMIT_API_KEY environment variable.".yellow()
);
None
}
} else {
None
}; };
manager.set_llm_provider(provider.to_string()); let default_model = get_default_model(&provider);
let model: String = Input::new()
.with_prompt("Model name")
.default(default_model.to_string())
.interact_text()?;
// Configure API key if needed let base_url: Option<String> = if provider == "ollama" {
if provider == "openai" { let url: String = Input::new()
let api_key: String = Input::new() .with_prompt("Ollama server URL")
.with_prompt(messages.openai_api_key()) .default("http://localhost:11434".to_string())
.interact_text()?; .interact_text()?;
manager.set_openai_api_key(api_key); Some(url)
} else if provider == "anthropic" { } else {
let api_key: String = Input::new() let use_custom_url = Confirm::new()
.with_prompt(messages.anthropic_api_key()) .with_prompt("Use custom API base URL?")
.interact_text()?; .default(false)
manager.set_anthropic_api_key(api_key); .interact()?;
} else if provider == "kimi" {
let api_key: String = Input::new() if use_custom_url {
.with_prompt(messages.kimi_api_key()) let url: String = Input::new().with_prompt("Base URL").interact_text()?;
.interact_text()?; Some(url)
manager.set_kimi_api_key(api_key); } else {
} else if provider == "deepseek" { None
let api_key: String = Input::new() }
.with_prompt(messages.deepseek_api_key()) };
.interact_text()?;
manager.set_deepseek_api_key(api_key); manager.set_llm_provider(provider.clone());
} else if provider == "openrouter" { manager.set_llm_model(model);
let api_key: String = Input::new() manager.set_llm_base_url(base_url);
.with_prompt(messages.openrouter_api_key())
.interact_text()?; if let Some(key) = api_key
manager.set_openrouter_api_key(api_key); && provider_needs_api_key(&provider)
{
manager.set_api_key(&key)?;
println!(
"\n{} {}",
"".green(),
"API key stored securely in system keyring.".green()
);
} }
Ok(()) Ok(())

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use clap::Parser; use clap::Parser;
use colored::Colorize; use colored::Colorize;
use dialoguer::{Confirm, Input, Select}; use dialoguer::{Confirm, Input, Select};
@@ -6,11 +6,11 @@ use semver::Version;
use std::path::PathBuf; use std::path::PathBuf;
use crate::config::{Language, manager::ConfigManager}; use crate::config::{Language, manager::ConfigManager};
use crate::git::{find_repo, GitRepo};
use crate::generator::ContentGenerator; use crate::generator::ContentGenerator;
use crate::git::tag::{ use crate::git::tag::{
bump_version, get_latest_version, suggest_version_bump, TagBuilder, VersionBump, TagBuilder, VersionBump, bump_version, get_latest_version, suggest_version_bump,
}; };
use crate::git::{GitRepo, find_repo};
use crate::i18n::Messages; use crate::i18n::Messages;
/// Generate and create Git tags /// Generate and create Git tags
@@ -56,6 +56,10 @@ pub struct TagCommand {
#[arg(long)] #[arg(long)]
dry_run: bool, dry_run: bool,
/// Enable thinking mode for this tag (override config)
#[arg(short = 't', long)]
think: bool,
/// Skip interactive prompts /// Skip interactive prompts
#[arg(short = 'y', long)] #[arg(short = 'y', long)]
yes: bool, yes: bool,
@@ -79,30 +83,37 @@ impl TagCommand {
} else if let Some(bump_str) = &self.bump { } else if let Some(bump_str) = &self.bump {
// Calculate bumped version // Calculate bumped version
let prefix = &config.tag.version_prefix; let prefix = &config.tag.version_prefix;
let latest = get_latest_version(&repo, prefix)? let latest =
.unwrap_or_else(|| Version::new(0, 0, 0)); get_latest_version(&repo, prefix)?.unwrap_or_else(|| Version::new(0, 0, 0));
let bump = VersionBump::from_str(bump_str)?; let bump = VersionBump::from_str(bump_str)?;
let new_version = bump_version(&latest, bump, None); let new_version = bump_version(&latest, bump, None);
format!("{}{}", prefix, new_version) format!("{}{}", prefix, new_version)
} else { } else {
// Interactive mode // Interactive mode
self.select_version_interactive(&repo, &config.tag.version_prefix, &messages).await? self.select_version_interactive(&repo, &config.tag.version_prefix, &messages)
.await?
}; };
// Validate tag name (if it looks like a version) // Validate tag name (if it looks like a version)
if tag_name.starts_with('v') || tag_name.chars().next().map(|c| c.is_ascii_digit()).unwrap_or(false) { if tag_name.starts_with('v')
|| tag_name
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
let version_str = tag_name.trim_start_matches('v'); let version_str = tag_name.trim_start_matches('v');
if let Err(e) = crate::utils::validators::validate_semver(version_str) { if let Err(e) = crate::utils::validators::validate_semver(version_str) {
println!("{}: {}", "Warning".yellow(), e); println!("{}: {}", "Warning".yellow(), e);
if !self.yes { if !self.yes {
let proceed = Confirm::new() let proceed = Confirm::new()
.with_prompt("Proceed with this tag name anyway?") .with_prompt("Proceed with this tag name anyway?")
.default(true) .default(true)
.interact()?; .interact()?;
if !proceed { if !proceed {
bail!("{}", messages.tag_cancelled()); bail!("{}", messages.tag_cancelled());
} }
@@ -116,7 +127,10 @@ impl TagCommand {
} else if let Some(msg) = &self.message { } else if let Some(msg) = &self.message {
Some(msg.clone()) Some(msg.clone())
} else if self.generate || (config.tag.auto_generate && !self.yes) { } else if self.generate || (config.tag.auto_generate && !self.yes) {
Some(self.generate_tag_message(&repo, &tag_name, &messages).await?) Some(
self.generate_tag_message(&repo, &tag_name, &messages)
.await?,
)
} else if !self.yes { } else if !self.yes {
Some(self.input_message_interactive(&tag_name, &messages)?) Some(self.input_message_interactive(&tag_name, &messages)?)
} else { } else {
@@ -184,12 +198,17 @@ impl TagCommand {
Ok(()) Ok(())
} }
async fn select_version_interactive(&self, repo: &GitRepo, prefix: &str, messages: &Messages) -> Result<String> { async fn select_version_interactive(
&self,
repo: &GitRepo,
prefix: &str,
messages: &Messages,
) -> Result<String> {
loop { loop {
let latest = get_latest_version(repo, prefix)?; let latest = get_latest_version(repo, prefix)?;
println!("\n{}", messages.version_selection().bold()); println!("\n{}", messages.version_selection().bold());
if let Some(ref version) = latest { if let Some(ref version) = latest {
println!("{} {}{}", messages.latest_version(), prefix, version); println!("{} {}{}", messages.latest_version(), prefix, version);
} else { } else {
@@ -216,36 +235,46 @@ impl TagCommand {
// Auto-detect // Auto-detect
let commits = repo.get_commits(50)?; let commits = repo.get_commits(50)?;
let bump = suggest_version_bump(&commits); let bump = suggest_version_bump(&commits);
let version = latest.as_ref() let version = latest
.as_ref()
.map(|v| bump_version(v, bump, None)) .map(|v| bump_version(v, bump, None))
.unwrap_or_else(|| Version::new(0, 1, 0)); .unwrap_or_else(|| Version::new(0, 1, 0));
println!("{} {:?}{}{}", messages.suggested_bump(), bump, prefix, version); println!(
"{} {:?}{}{}",
messages.suggested_bump(),
bump,
prefix,
version
);
let confirm = Confirm::new() let confirm = Confirm::new()
.with_prompt(messages.use_this_version()) .with_prompt(messages.use_this_version())
.default(true) .default(true)
.interact()?; .interact()?;
if confirm { if confirm {
return Ok(format!("{}{}", prefix, version)); return Ok(format!("{}{}", prefix, version));
} }
// User rejected, continue the loop // User rejected, continue the loop
} }
1 => { 1 => {
let version = latest.as_ref() let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Major, None)) .map(|v| bump_version(v, VersionBump::Major, None))
.unwrap_or_else(|| Version::new(1, 0, 0)); .unwrap_or_else(|| Version::new(1, 0, 0));
return Ok(format!("{}{}", prefix, version)); return Ok(format!("{}{}", prefix, version));
} }
2 => { 2 => {
let version = latest.as_ref() let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Minor, None)) .map(|v| bump_version(v, VersionBump::Minor, None))
.unwrap_or_else(|| Version::new(0, 1, 0)); .unwrap_or_else(|| Version::new(0, 1, 0));
return Ok(format!("{}{}", prefix, version)); return Ok(format!("{}{}", prefix, version));
} }
3 => { 3 => {
let version = latest.as_ref() let version = latest
.as_ref()
.map(|v| bump_version(v, VersionBump::Patch, None)) .map(|v| bump_version(v, VersionBump::Patch, None))
.unwrap_or_else(|| Version::new(0, 0, 1)); .unwrap_or_else(|| Version::new(0, 0, 1));
return Ok(format!("{}{}", prefix, version)); return Ok(format!("{}{}", prefix, version));
@@ -268,12 +297,15 @@ impl TagCommand {
} }
} }
async fn generate_tag_message(&self, repo: &GitRepo, version: &str, messages: &Messages) -> Result<String> { async fn generate_tag_message(
&self,
repo: &GitRepo,
version: &str,
messages: &Messages,
) -> Result<String> {
let manager = ConfigManager::new()?; let manager = ConfigManager::new()?;
let config = manager.config();
let language = manager.get_language().unwrap_or(Language::English); let language = manager.get_language().unwrap_or(Language::English);
// Get commits since last tag
let tags = repo.get_tags()?; let tags = repo.get_tags()?;
let commits = if let Some(latest_tag) = tags.first() { let commits = if let Some(latest_tag) = tags.first() {
repo.get_commits_between(&latest_tag.name, "HEAD")? repo.get_commits_between(&latest_tag.name, "HEAD")?
@@ -287,18 +319,20 @@ impl TagCommand {
println!("{}", messages.ai_generating_tag(commits.len())); println!("{}", messages.ai_generating_tag(commits.len()));
let generator = ContentGenerator::new(&config.llm).await?; let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
generator.generate_tag_message(version, &commits, language).await generator
.generate_tag_message(version, &commits, language)
.await
} }
fn input_message_interactive(&self, version: &str, messages: &Messages) -> Result<String> { fn input_message_interactive(&self, version: &str, messages: &Messages) -> Result<String> {
let default_msg = format!("Release {}", version); let default_msg = format!("Release {}", version);
let use_editor = Confirm::new() let use_editor = Confirm::new()
.with_prompt(messages.open_editor()) .with_prompt(messages.open_editor())
.default(false) .default(false)
.interact()?; .interact()?;
if use_editor { if use_editor {
crate::utils::editor::edit_content(&default_msg) crate::utils::editor::edit_content(&default_msg)
} else { } else {

View File

@@ -1,6 +1,9 @@
use super::{AppConfig, GitProfile, TokenConfig}; use super::{AppConfig, GitProfile, TokenConfig};
use anyhow::{bail, Context, Result}; use crate::utils::keyring::{
use std::collections::HashMap; KeyringManager, get_default_base_url, get_default_model, provider_needs_api_key,
};
use anyhow::{Context, Result, bail};
// use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
/// Configuration manager /// Configuration manager
@@ -8,6 +11,7 @@ pub struct ConfigManager {
config: AppConfig, config: AppConfig,
config_path: PathBuf, config_path: PathBuf,
modified: bool, modified: bool,
keyring: KeyringManager,
} }
impl ConfigManager { impl ConfigManager {
@@ -28,6 +32,7 @@ impl ConfigManager {
config, config,
config_path: path.to_path_buf(), config_path: path.to_path_buf(),
modified: false, modified: false,
keyring: KeyringManager::new(),
}) })
} }
@@ -37,6 +42,7 @@ impl ConfigManager {
config: AppConfig::default(), config: AppConfig::default(),
config_path: path.to_path_buf(), config_path: path.to_path_buf(),
modified: true, modified: true,
keyring: KeyringManager::new(),
}) })
} }
@@ -60,10 +66,10 @@ impl ConfigManager {
Ok(()) Ok(())
} }
/// Force save configuration // /// Force save configuration
pub fn force_save(&self) -> Result<()> { // pub fn force_save(&self) -> Result<()> {
self.config.save(&self.config_path) // self.config.save(&self.config_path)
} // }
/// Get configuration file path /// Get configuration file path
pub fn path(&self) -> &Path { pub fn path(&self) -> &Path {
@@ -87,13 +93,13 @@ impl ConfigManager {
if !self.config.profiles.contains_key(name) { if !self.config.profiles.contains_key(name) {
bail!("Profile '{}' does not exist", name); bail!("Profile '{}' does not exist", name);
} }
if self.config.default_profile.as_ref() == Some(&name.to_string()) { if self.config.default_profile.as_ref() == Some(&name.to_string()) {
self.config.default_profile = None; self.config.default_profile = None;
} }
self.config.repo_profiles.retain(|_, v| v != name); self.config.repo_profiles.retain(|_, v| v != name);
self.config.profiles.remove(name); self.config.profiles.remove(name);
self.modified = true; self.modified = true;
Ok(()) Ok(())
@@ -114,11 +120,11 @@ impl ConfigManager {
self.config.profiles.get(name) self.config.profiles.get(name)
} }
/// Get mutable profile // /// Get mutable profile
pub fn get_profile_mut(&mut self, name: &str) -> Option<&mut GitProfile> { // pub fn get_profile_mut(&mut self, name: &str) -> Option<&mut GitProfile> {
self.modified = true; // self.modified = true;
self.config.profiles.get_mut(name) // self.config.profiles.get_mut(name)
} // }
/// List all profile names /// List all profile names
pub fn list_profiles(&self) -> Vec<&String> { pub fn list_profiles(&self) -> Vec<&String> {
@@ -132,10 +138,10 @@ impl ConfigManager {
/// Set default profile /// Set default profile
pub fn set_default_profile(&mut self, name: Option<String>) -> Result<()> { pub fn set_default_profile(&mut self, name: Option<String>) -> Result<()> {
if let Some(ref n) = name { if let Some(ref n) = name
if !self.config.profiles.contains_key(n) { && !self.config.profiles.contains_key(n)
bail!("Profile '{}' does not exist", n); {
} bail!("Profile '{}' does not exist", n);
} }
self.config.default_profile = name; self.config.default_profile = name;
self.modified = true; self.modified = true;
@@ -166,54 +172,138 @@ impl ConfigManager {
} }
} }
/// Get profile usage statistics // /// Get profile usage statistics
pub fn get_profile_usage(&self, name: &str) -> Option<&super::UsageStats> { // pub fn get_profile_usage(&self, name: &str) -> Option<&super::UsageStats> {
self.config.profiles.get(name).map(|p| &p.usage) // self.config.profiles.get(name).map(|p| &p.usage)
} // }
// Token management // Token management
/// Add a token to a profile /// Add a token to a profile (stores token in keyring)
pub fn add_token_to_profile(&mut self, profile_name: &str, service: String, token: TokenConfig) -> Result<()> { pub fn add_token_to_profile(
&mut self,
profile_name: &str,
service: String,
token: TokenConfig,
) -> Result<()> {
if !self.config.profiles.contains_key(profile_name) {
bail!("Profile '{}' does not exist", profile_name);
}
if let Some(profile) = self.config.profiles.get_mut(profile_name) { if let Some(profile) = self.config.profiles.get_mut(profile_name) {
profile.add_token(service, token); profile.add_token(service, token);
self.modified = true; self.modified = true;
Ok(()) }
Ok(())
}
/// Store a PAT token in keyring for a profile
pub fn store_pat_for_profile(
&self,
profile_name: &str,
service: &str,
token_value: &str,
) -> Result<()> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
let user_email = &profile.user_email;
self.keyring
.store_pat(profile_name, user_email, service, token_value)
}
/// Get a PAT token from keyring for a profile
pub fn get_pat_for_profile(&self, profile_name: &str, service: &str) -> Result<Option<String>> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
let user_email = &profile.user_email;
self.keyring.get_pat(profile_name, user_email, service)
}
/// Check if a PAT token exists for a profile
pub fn has_pat_for_profile(&self, profile_name: &str, service: &str) -> bool {
if let Some(profile) = self.get_profile(profile_name) {
let user_email = &profile.user_email;
self.keyring.has_pat(profile_name, user_email, service)
} else { } else {
bail!("Profile '{}' does not exist", profile_name); false
} }
} }
/// Get a token from a profile /// Remove a token from a profile (deletes from keyring)
pub fn get_token_from_profile(&self, profile_name: &str, service: &str) -> Option<&TokenConfig> {
self.config.profiles.get(profile_name)?.get_token(service)
}
/// Remove a token from a profile
pub fn remove_token_from_profile(&mut self, profile_name: &str, service: &str) -> Result<()> { pub fn remove_token_from_profile(&mut self, profile_name: &str, service: &str) -> Result<()> {
if !self.config.profiles.contains_key(profile_name) {
bail!("Profile '{}' does not exist", profile_name);
}
let user_email = self
.config
.profiles
.get(profile_name)
.unwrap()
.user_email
.clone();
let services: Vec<String> = self
.config
.profiles
.get(profile_name)
.unwrap()
.tokens
.keys()
.cloned()
.collect();
if !services.contains(&service.to_string()) {
bail!(
"Token for service '{}' not found in profile '{}'",
service,
profile_name
);
}
self.keyring
.delete_pat(profile_name, &user_email, service)?;
if let Some(profile) = self.config.profiles.get_mut(profile_name) { if let Some(profile) = self.config.profiles.get_mut(profile_name) {
profile.remove_token(service); profile.remove_token(service);
self.modified = true; self.modified = true;
Ok(())
} else {
bail!("Profile '{}' does not exist", profile_name);
} }
Ok(())
} }
/// List all tokens in a profile /// Delete all PAT tokens for a profile (used when removing a profile)
pub fn list_profile_tokens(&self, profile_name: &str) -> Option<Vec<&String>> { pub fn delete_all_pats_for_profile(&self, profile_name: &str) -> Result<()> {
self.config.profiles.get(profile_name).map(|p| p.tokens.keys().collect()) if let Some(profile) = self.get_profile(profile_name) {
let user_email = &profile.user_email;
let services: Vec<String> = profile.tokens.keys().cloned().collect();
self.keyring
.delete_all_pats_for_profile(profile_name, user_email, &services)?;
}
Ok(())
} }
// /// List all tokens in a profile
// pub fn list_profile_tokens(&self, profile_name: &str) -> Option<Vec<&String>> {
// self.config.profiles.get(profile_name).map(|p| p.tokens.keys().collect())
// }
// Repository profile management // Repository profile management
/// Get profile for repository // /// Get profile for repository
pub fn get_repo_profile(&self, repo_path: &str) -> Option<&GitProfile> { // pub fn get_repo_profile(&self, repo_path: &str) -> Option<&GitProfile> {
self.config // self.config
.repo_profiles // .repo_profiles
.get(repo_path) // .get(repo_path)
.and_then(|name| self.config.profiles.get(name)) // .and_then(|name| self.config.profiles.get(name))
} // }
/// Set profile for repository /// Set profile for repository
pub fn set_repo_profile(&mut self, repo_path: String, profile_name: String) -> Result<()> { pub fn set_repo_profile(&mut self, repo_path: String, profile_name: String) -> Result<()> {
@@ -225,32 +315,75 @@ impl ConfigManager {
Ok(()) Ok(())
} }
/// Remove repository profile mapping // /// Remove repository profile mapping
pub fn remove_repo_profile(&mut self, repo_path: &str) { // pub fn remove_repo_profile(&mut self, repo_path: &str) {
self.config.repo_profiles.remove(repo_path); // self.config.repo_profiles.remove(repo_path);
self.modified = true; // self.modified = true;
// }
// /// List repository profile mappings
// pub fn list_repo_profiles(&self) -> &HashMap<String, String> {
// &self.config.repo_profiles
// }
// /// Get effective profile for a repository (repo-specific -> default)
// pub fn get_effective_profile(&self, repo_path: Option<&str>) -> Option<&GitProfile> {
// if let Some(path) = repo_path {
// if let Some(profile) = self.get_repo_profile(path) {
// return Some(profile);
// }
// }
// self.default_profile()
// }
/// Check and compare profile with git configuration
pub fn check_profile_config(
&self,
profile_name: &str,
repo: &git2::Repository,
) -> Result<super::ProfileComparison> {
let profile = self
.get_profile(profile_name)
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?;
profile.compare_with_git_config(repo)
} }
/// List repository profile mappings /// Find a profile that matches the given user config (name, email, signing_key)
pub fn list_repo_profiles(&self) -> &HashMap<String, String> { pub fn find_matching_profile(
&self.config.repo_profiles &self,
} user_name: &str,
user_email: &str,
signing_key: Option<&str>,
) -> Option<&GitProfile> {
for profile in self.config.profiles.values() {
let name_match = profile.user_name == user_name;
let email_match = profile.user_email == user_email;
let key_match = match (signing_key, profile.signing_key()) {
(Some(git_key), Some(profile_key)) => git_key == profile_key,
(None, None) => true,
(Some(_), None) => false,
(None, Some(_)) => false,
};
/// Get effective profile for a repository (repo-specific -> default) if name_match && email_match && key_match {
pub fn get_effective_profile(&self, repo_path: Option<&str>) -> Option<&GitProfile> {
if let Some(path) = repo_path {
if let Some(profile) = self.get_repo_profile(path) {
return Some(profile); return Some(profile);
} }
} }
self.default_profile() None
} }
/// Check and compare profile with git configuration /// Find profiles that partially match (same name or same email)
pub fn check_profile_config(&self, profile_name: &str, repo: &git2::Repository) -> Result<super::ProfileComparison> { pub fn find_partial_matches(&self, user_name: &str, user_email: &str) -> Vec<&GitProfile> {
let profile = self.get_profile(profile_name) self.config
.ok_or_else(|| anyhow::anyhow!("Profile '{}' not found", profile_name))?; .profiles
profile.compare_with_git_config(repo) .values()
.filter(|p| p.user_name == user_name || p.user_email == user_email)
.collect()
}
/// Get repo profile mapping
pub fn get_repo_profile_name(&self, repo_path: &str) -> Option<&String> {
self.config.repo_profiles.get(repo_path)
} }
// LLM configuration // LLM configuration
@@ -262,104 +395,169 @@ impl ConfigManager {
/// Set LLM provider /// Set LLM provider
pub fn set_llm_provider(&mut self, provider: String) { pub fn set_llm_provider(&mut self, provider: String) {
self.config.llm.provider = provider; let default_model = get_default_model(&provider);
self.config.llm.provider = provider.clone();
if self.config.llm.model.is_empty() || self.config.llm.model == "llama2" {
self.config.llm.model = default_model.to_string();
}
self.modified = true; self.modified = true;
} }
/// Get OpenAI API key /// Get model
pub fn openai_api_key(&self) -> Option<&String> { pub fn llm_model(&self) -> &str {
self.config.llm.openai.api_key.as_ref() &self.config.llm.model
} }
/// Set OpenAI API key /// Set model
pub fn set_openai_api_key(&mut self, key: String) { pub fn set_llm_model(&mut self, model: String) {
self.config.llm.openai.api_key = Some(key); self.config.llm.model = model;
self.modified = true; self.modified = true;
} }
/// Get Anthropic API key /// Get base URL (returns provider default if not set)
pub fn anthropic_api_key(&self) -> Option<&String> { pub fn llm_base_url(&self) -> String {
self.config.llm.anthropic.api_key.as_ref() match &self.config.llm.base_url {
Some(url) => url.clone(),
None => get_default_base_url(&self.config.llm.provider).to_string(),
}
} }
/// Set Anthropic API key /// Set base URL
pub fn set_anthropic_api_key(&mut self, key: String) { pub fn set_llm_base_url(&mut self, url: Option<String>) {
self.config.llm.anthropic.api_key = Some(key); self.config.llm.base_url = url;
self.modified = true; self.modified = true;
} }
/// Get Kimi API key /// Get API key from configured storage method
pub fn kimi_api_key(&self) -> Option<&String> { pub fn get_api_key(&self) -> Option<String> {
self.config.llm.kimi.api_key.as_ref() // First try environment variables (always checked)
if let Some(key) = self
.keyring
.get_api_key(&self.config.llm.provider)
.unwrap_or(None)
{
return Some(key);
}
// Then try config file if configured
if self.config.llm.api_key_storage == "config" {
return self.config.llm.api_key.clone();
}
None
} }
/// Set Kimi API key /// Store API key in configured storage method
pub fn set_kimi_api_key(&mut self, key: String) { pub fn set_api_key(&self, api_key: &str) -> Result<()> {
self.config.llm.kimi.api_key = Some(key); match self.config.llm.api_key_storage.as_str() {
self.modified = true; "keyring" => {
if !self.keyring.is_available() {
bail!(
"Keyring is not available. Set QUICOMMIT_API_KEY environment variable instead or change api_key_storage to 'config'."
);
}
self.keyring
.store_api_key(&self.config.llm.provider, api_key)
}
"config" => {
// We can't modify self.config here since self is immutable
// This will be handled by the caller updating the config
Ok(())
}
"environment" => {
bail!(
"API key storage set to 'environment'. Please set QUICOMMIT_{}_API_KEY environment variable.",
self.config.llm.provider.to_uppercase()
);
}
_ => {
bail!(
"Invalid API key storage method: {}",
self.config.llm.api_key_storage
);
}
}
} }
/// Get Kimi base URL /// Delete API key from configured storage method
pub fn kimi_base_url(&self) -> &str { pub fn delete_api_key(&self) -> Result<()> {
&self.config.llm.kimi.base_url match self.config.llm.api_key_storage.as_str() {
"keyring" => {
if self.keyring.is_available() {
self.keyring.delete_api_key(&self.config.llm.provider)?;
}
}
"config" => {
// We can't modify self.config here since self is immutable
// This will be handled by the caller updating the config
}
"environment" => {
// Environment variables are not managed by the app
}
_ => {
bail!(
"Invalid API key storage method: {}",
self.config.llm.api_key_storage
);
}
}
Ok(())
} }
/// Set Kimi base URL /// Check if API key is configured
pub fn set_kimi_base_url(&mut self, url: String) { pub fn has_api_key(&self) -> bool {
self.config.llm.kimi.base_url = url; if !provider_needs_api_key(&self.config.llm.provider) {
self.modified = true; return true;
}
// Check environment variables
if self
.keyring
.get_api_key(&self.config.llm.provider)
.unwrap_or(None)
.is_some()
{
return true;
}
// Check config file if configured
if self.config.llm.api_key_storage == "config" {
return self.config.llm.api_key.is_some();
}
false
} }
/// Get DeepSeek API key /// Get keyring manager reference
pub fn deepseek_api_key(&self) -> Option<&String> { pub fn keyring(&self) -> &KeyringManager {
self.config.llm.deepseek.api_key.as_ref() &self.keyring
} }
/// Set DeepSeek API key // /// Configure LLM provider with all settings
pub fn set_deepseek_api_key(&mut self, key: String) { // pub fn configure_llm(&mut self, provider: String, model: Option<String>, base_url: Option<String>, api_key: Option<&str>) -> Result<()> {
self.config.llm.deepseek.api_key = Some(key); // self.set_llm_provider(provider.clone());
self.modified = true;
}
/// Get DeepSeek base URL // if let Some(m) = model {
pub fn deepseek_base_url(&self) -> &str { // self.set_llm_model(m);
&self.config.llm.deepseek.base_url // }
}
/// Set DeepSeek base URL // self.set_llm_base_url(base_url);
pub fn set_deepseek_base_url(&mut self, url: String) {
self.config.llm.deepseek.base_url = url;
self.modified = true;
}
/// Get OpenRouter API key // if let Some(key) = api_key {
pub fn openrouter_api_key(&self) -> Option<&String> { // if provider_needs_api_key(&provider) {
self.config.llm.openrouter.api_key.as_ref() // self.set_api_key(key)?;
} // }
// }
/// Set OpenRouter API key // Ok(())
pub fn set_openrouter_api_key(&mut self, key: String) { // }
self.config.llm.openrouter.api_key = Some(key);
self.modified = true;
}
/// Get OpenRouter base URL
pub fn openrouter_base_url(&self) -> &str {
&self.config.llm.openrouter.base_url
}
/// Set OpenRouter base URL
pub fn set_openrouter_base_url(&mut self, url: String) {
self.config.llm.openrouter.base_url = url;
self.modified = true;
}
// Commit configuration // Commit configuration
/// Get commit format // /// Get commit format
pub fn commit_format(&self) -> super::CommitFormat { // pub fn commit_format(&self) -> super::CommitFormat {
self.config.commit.format // self.config.commit.format
} // }
/// Set commit format /// Set commit format
pub fn set_commit_format(&mut self, format: super::CommitFormat) { pub fn set_commit_format(&mut self, format: super::CommitFormat) {
@@ -367,10 +565,10 @@ impl ConfigManager {
self.modified = true; self.modified = true;
} }
/// Check if auto-generate is enabled // /// Check if auto-generate is enabled
pub fn auto_generate_commits(&self) -> bool { // pub fn auto_generate_commits(&self) -> bool {
self.config.commit.auto_generate // self.config.commit.auto_generate
} // }
/// Set auto-generate commits /// Set auto-generate commits
pub fn set_auto_generate_commits(&mut self, enabled: bool) { pub fn set_auto_generate_commits(&mut self, enabled: bool) {
@@ -380,10 +578,10 @@ impl ConfigManager {
// Tag configuration // Tag configuration
/// Get version prefix // /// Get version prefix
pub fn version_prefix(&self) -> &str { // pub fn version_prefix(&self) -> &str {
&self.config.tag.version_prefix // &self.config.tag.version_prefix
} // }
/// Set version prefix /// Set version prefix
pub fn set_version_prefix(&mut self, prefix: String) { pub fn set_version_prefix(&mut self, prefix: String) {
@@ -393,10 +591,10 @@ impl ConfigManager {
// Changelog configuration // Changelog configuration
/// Get changelog path // /// Get changelog path
pub fn changelog_path(&self) -> &str { // pub fn changelog_path(&self) -> &str {
&self.config.changelog.path // &self.config.changelog.path
} // }
/// Set changelog path /// Set changelog path
pub fn set_changelog_path(&mut self, path: String) { pub fn set_changelog_path(&mut self, path: String) {
@@ -406,10 +604,10 @@ impl ConfigManager {
// Language configuration // Language configuration
/// Get output language // /// Get output language
pub fn output_language(&self) -> &str { // pub fn output_language(&self) -> &str {
&self.config.language.output_language // &self.config.language.output_language
} // }
/// Set output language /// Set output language
pub fn set_output_language(&mut self, language: String) { pub fn set_output_language(&mut self, language: String) {
@@ -446,14 +644,12 @@ impl ConfigManager {
/// Export configuration to TOML string /// Export configuration to TOML string
pub fn export(&self) -> Result<String> { pub fn export(&self) -> Result<String> {
toml::to_string_pretty(&self.config) toml::to_string_pretty(&self.config).context("Failed to serialize config")
.context("Failed to serialize config")
} }
/// Import configuration from TOML string /// Import configuration from TOML string
pub fn import(&mut self, toml_str: &str) -> Result<()> { pub fn import(&mut self, toml_str: &str) -> Result<()> {
self.config = toml::from_str(toml_str) self.config = toml::from_str(toml_str).context("Failed to parse config")?;
.context("Failed to parse config")?;
self.modified = true; self.modified = true;
Ok(()) Ok(())
} }
@@ -471,6 +667,7 @@ impl Default for ConfigManager {
config: AppConfig::default(), config: AppConfig::default(),
config_path: PathBuf::new(), config_path: PathBuf::new(),
modified: false, modified: false,
keyring: KeyringManager::new(),
} }
} }
} }

View File

@@ -7,10 +7,7 @@ use std::path::{Path, PathBuf};
pub mod manager; pub mod manager;
pub mod profile; pub mod profile;
pub use profile::{ pub use profile::{GitProfile, ProfileComparison, TokenConfig, TokenType};
GitProfile, TokenConfig, TokenType,
UsageStats, ProfileComparison
};
/// Application configuration /// Application configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -80,37 +77,16 @@ impl Default for AppConfig {
/// LLM configuration /// LLM configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig { pub struct LlmConfig {
/// Default LLM provider /// Current LLM provider (ollama, openai, anthropic, kimi, deepseek, openrouter)
#[serde(default = "default_llm_provider")] #[serde(default = "default_llm_provider")]
pub provider: String, pub provider: String,
/// OpenAI configuration /// Model to use (stored in config, not in keyring)
#[serde(default)] #[serde(default = "default_model")]
pub openai: OpenAiConfig, pub model: String,
/// Ollama configuration /// API base URL (optional, will use provider default if not set)
#[serde(default)] pub base_url: Option<String>,
pub ollama: OllamaConfig,
/// Anthropic Claude configuration
#[serde(default)]
pub anthropic: AnthropicConfig,
/// Kimi (Moonshot AI) configuration
#[serde(default)]
pub kimi: KimiConfig,
/// DeepSeek configuration
#[serde(default)]
pub deepseek: DeepSeekConfig,
/// OpenRouter configuration
#[serde(default)]
pub openrouter: OpenRouterConfig,
/// Custom API configuration
#[serde(default)]
pub custom: Option<CustomLlmConfig>,
/// Maximum tokens for generation /// Maximum tokens for generation
#[serde(default = "default_max_tokens")] #[serde(default = "default_max_tokens")]
@@ -123,186 +99,45 @@ pub struct LlmConfig {
/// Timeout in seconds /// Timeout in seconds
#[serde(default = "default_timeout")] #[serde(default = "default_timeout")]
pub timeout: u64, pub timeout: u64,
/// API key storage method (keyring, config, environment)
#[serde(default = "default_api_key_storage")]
pub api_key_storage: String,
/// API key (stored in config for fallback, encrypted if encrypt_sensitive is true)
#[serde(default)]
pub api_key: Option<String>,
/// Enable thinking/reasoning mode (deepseek, kimi, anthropic)
#[serde(default)]
pub thinking_enabled: bool,
/// Budget tokens for thinking mode (Anthropic Claude 4)
#[serde(default)]
pub thinking_budget_tokens: Option<u32>,
}
fn default_api_key_storage() -> String {
"keyring".to_string()
} }
impl Default for LlmConfig { impl Default for LlmConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
provider: default_llm_provider(), provider: default_llm_provider(),
openai: OpenAiConfig::default(), model: default_model(),
ollama: OllamaConfig::default(), base_url: None,
anthropic: AnthropicConfig::default(),
kimi: KimiConfig::default(),
deepseek: DeepSeekConfig::default(),
openrouter: OpenRouterConfig::default(),
custom: None,
max_tokens: default_max_tokens(), max_tokens: default_max_tokens(),
temperature: default_temperature(), temperature: default_temperature(),
timeout: default_timeout(), timeout: default_timeout(),
} api_key_storage: default_api_key_storage(),
}
}
/// OpenAI API configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_openai_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_openai_base_url")]
pub base_url: String,
}
impl Default for OpenAiConfig {
fn default() -> Self {
Self {
api_key: None, api_key: None,
model: default_openai_model(), thinking_enabled: false,
base_url: default_openai_base_url(), thinking_budget_tokens: None,
} }
} }
} }
/// Ollama configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
/// Ollama server URL
#[serde(default = "default_ollama_url")]
pub url: String,
/// Model to use
#[serde(default = "default_ollama_model")]
pub model: String,
}
impl Default for OllamaConfig {
fn default() -> Self {
Self {
url: default_ollama_url(),
model: default_ollama_model(),
}
}
}
/// Anthropic Claude configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_anthropic_model")]
pub model: String,
}
impl Default for AnthropicConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_anthropic_model(),
}
}
}
/// Kimi (Moonshot AI) configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KimiConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_kimi_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_kimi_base_url")]
pub base_url: String,
}
impl Default for KimiConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_kimi_model(),
base_url: default_kimi_base_url(),
}
}
}
/// DeepSeek configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_deepseek_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_deepseek_base_url")]
pub base_url: String,
}
impl Default for DeepSeekConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_deepseek_model(),
base_url: default_deepseek_base_url(),
}
}
}
/// OpenRouter configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterConfig {
/// API key
pub api_key: Option<String>,
/// Model to use
#[serde(default = "default_openrouter_model")]
pub model: String,
/// API base URL (for custom endpoints)
#[serde(default = "default_openrouter_base_url")]
pub base_url: String,
}
impl Default for OpenRouterConfig {
fn default() -> Self {
Self {
api_key: None,
model: default_openrouter_model(),
base_url: default_openrouter_base_url(),
}
}
}
/// Custom LLM API configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomLlmConfig {
/// API endpoint URL
pub url: String,
/// API key (optional)
pub api_key: Option<String>,
/// Model name
pub model: String,
/// Request format template (JSON)
pub request_template: String,
/// Response path to extract content (e.g., "choices.0.message.content")
pub response_path: String,
}
/// Commit configuration /// Commit configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitConfig { pub struct CommitConfig {
@@ -592,6 +427,10 @@ fn default_llm_provider() -> String {
"ollama".to_string() "ollama".to_string()
} }
fn default_model() -> String {
"llama2".to_string()
}
fn default_max_tokens() -> u32 { fn default_max_tokens() -> u32 {
500 500
} }
@@ -604,50 +443,6 @@ fn default_timeout() -> u64 {
30 30
} }
fn default_openai_model() -> String {
"gpt-4".to_string()
}
fn default_openai_base_url() -> String {
"https://api.openai.com/v1".to_string()
}
fn default_ollama_url() -> String {
"http://localhost:11434".to_string()
}
fn default_ollama_model() -> String {
"llama2".to_string()
}
fn default_anthropic_model() -> String {
"claude-3-sonnet-20240229".to_string()
}
fn default_kimi_model() -> String {
"moonshot-v1-8k".to_string()
}
fn default_kimi_base_url() -> String {
"https://api.moonshot.cn/v1".to_string()
}
fn default_deepseek_model() -> String {
"deepseek-chat".to_string()
}
fn default_deepseek_base_url() -> String {
"https://api.deepseek.com/v1".to_string()
}
fn default_openrouter_model() -> String {
"openai/gpt-3.5-turbo".to_string()
}
fn default_openrouter_base_url() -> String {
"https://openrouter.ai/api/v1".to_string()
}
fn default_commit_format() -> CommitFormat { fn default_commit_format() -> CommitFormat {
CommitFormat::Conventional CommitFormat::Conventional
} }
@@ -696,39 +491,89 @@ impl AppConfig {
/// Save configuration to file /// Save configuration to file
pub fn save(&self, path: &Path) -> Result<()> { pub fn save(&self, path: &Path) -> Result<()> {
let content = toml::to_string_pretty(self) let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
.context("Failed to serialize config")?;
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
fs::create_dir_all(parent) fs::create_dir_all(parent)
.with_context(|| format!("Failed to create config directory: {:?}", parent))?; .with_context(|| format!("Failed to create config directory: {:?}", parent))?;
} }
fs::write(path, content) fs::write(path, content)
.with_context(|| format!("Failed to write config file: {:?}", path))?; .with_context(|| format!("Failed to write config file: {:?}", path))?;
Ok(()) Ok(())
} }
/// Get default config path /// Get default config path
pub fn default_path() -> Result<PathBuf> { pub fn default_path() -> Result<PathBuf> {
let config_dir = dirs::config_dir() let config_dir = dirs::config_dir().context("Could not find config directory")?;
.context("Could not find config directory")?;
Ok(config_dir.join("quicommit").join("config.toml")) Ok(config_dir.join("quicommit").join("config.toml"))
} }
/// Get profile for a repository // /// Get profile for a repository
pub fn get_profile_for_repo(&self, repo_path: &str) -> Option<&GitProfile> { // pub fn get_profile_for_repo(&self, repo_path: &str) -> Option<&GitProfile> {
let profile_name = self.repo_profiles.get(repo_path)?; // let profile_name = self.repo_profiles.get(repo_path)?;
self.profiles.get(profile_name) // self.profiles.get(profile_name)
// }
// /// Set profile for a repository
// pub fn set_profile_for_repo(&mut self, repo_path: String, profile_name: String) -> Result<()> {
// if !self.profiles.contains_key(&profile_name) {
// anyhow::bail!("Profile '{}' does not exist", profile_name);
// }
// self.repo_profiles.insert(repo_path, profile_name);
// Ok(())
// }
}
/// Encrypted PAT data for export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedPat {
/// Profile name
pub profile_name: String,
/// Service name (e.g., github, gitlab)
pub service: String,
/// User email (for keyring lookup)
pub user_email: String,
/// Encrypted token value
pub encrypted_token: String,
}
/// Export data container with optional encrypted PATs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportData {
/// Configuration content (TOML string)
pub config: String,
/// Encrypted PATs (only present when exporting with encryption)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub encrypted_pats: Vec<EncryptedPat>,
/// Export version for future compatibility
#[serde(default = "default_export_version")]
pub export_version: String,
}
fn default_export_version() -> String {
"1".to_string()
}
impl ExportData {
pub fn new(config: String) -> Self {
Self {
config,
encrypted_pats: Vec::new(),
export_version: default_export_version(),
}
} }
/// Set profile for a repository pub fn with_encrypted_pats(config: String, pats: Vec<EncryptedPat>) -> Self {
pub fn set_profile_for_repo(&mut self, repo_path: String, profile_name: String) -> Result<()> { Self {
if !self.profiles.contains_key(&profile_name) { config,
anyhow::bail!("Profile '{}' does not exist", profile_name); encrypted_pats: pats,
export_version: default_export_version(),
} }
self.repo_profiles.insert(repo_path, profile_name); }
Ok(())
pub fn has_encrypted_pats(&self) -> bool {
!self.encrypted_pats.is_empty()
} }
} }

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
@@ -80,25 +80,25 @@ impl GitProfile {
if self.user_name.is_empty() { if self.user_name.is_empty() {
bail!("User name cannot be empty"); bail!("User name cannot be empty");
} }
if self.user_email.is_empty() { if self.user_email.is_empty() {
bail!("User email cannot be empty"); bail!("User email cannot be empty");
} }
crate::utils::validators::validate_email(&self.user_email)?; crate::utils::validators::validate_email(&self.user_email)?;
if let Some(ref ssh) = self.ssh { if let Some(ref ssh) = self.ssh {
ssh.validate()?; ssh.validate()?;
} }
if let Some(ref gpg) = self.gpg { if let Some(ref gpg) = self.gpg {
gpg.validate()?; gpg.validate()?;
} }
for token in self.tokens.values() { for token in self.tokens.values() {
token.validate()?; token.validate()?;
} }
Ok(()) Ok(())
} }
@@ -120,8 +120,7 @@ impl GitProfile {
/// Get signing key (from GPG config or direct) /// Get signing key (from GPG config or direct)
pub fn signing_key(&self) -> Option<&str> { pub fn signing_key(&self) -> Option<&str> {
self.signing_key self.signing_key
.as_ref() .as_deref()
.map(|s| s.as_str())
.or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str())) .or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str()))
} }
@@ -144,7 +143,7 @@ impl GitProfile {
pub fn record_usage(&mut self, repo_path: Option<String>) { pub fn record_usage(&mut self, repo_path: Option<String>) {
self.usage.last_used = Some(chrono::Utc::now().to_rfc3339()); self.usage.last_used = Some(chrono::Utc::now().to_rfc3339());
self.usage.total_uses += 1; self.usage.total_uses += 1;
if let Some(repo) = repo_path { if let Some(repo) = repo_path {
let count = self.usage.repo_usage.entry(repo).or_insert(0); let count = self.usage.repo_usage.entry(repo).or_insert(0);
*count += 1; *count += 1;
@@ -159,93 +158,95 @@ impl GitProfile {
/// Apply this profile to a git repository (local config) /// Apply this profile to a git repository (local config)
pub fn apply_to_repo(&self, repo: &git2::Repository) -> Result<()> { pub fn apply_to_repo(&self, repo: &git2::Repository) -> Result<()> {
let mut config = repo.config()?; let mut config = repo.config()?;
config.set_str("user.name", &self.user_name)?; config.set_str("user.name", &self.user_name)?;
config.set_str("user.email", &self.user_email)?; config.set_str("user.email", &self.user_email)?;
if let Some(key) = self.signing_key() { if let Some(key) = self.signing_key() {
config.set_str("user.signingkey", key)?; config.set_str("user.signingkey", key)?;
if self.settings.auto_sign_commits { if self.settings.auto_sign_commits {
config.set_bool("commit.gpgsign", true)?; config.set_bool("commit.gpgsign", true)?;
} }
if self.settings.auto_sign_tags { if self.settings.auto_sign_tags {
config.set_bool("tag.gpgsign", true)?; config.set_bool("tag.gpgsign", true)?;
} }
} }
if let Some(ref ssh) = self.ssh { if let Some(ref ssh) = self.ssh
if let Some(ref key_path) = ssh.private_key_path { && let Some(ref key_path) = ssh.private_key_path
let path_str = key_path.display().to_string(); {
#[cfg(target_os = "windows")] let path_str = key_path.display().to_string();
{ #[cfg(target_os = "windows")]
config.set_str("core.sshCommand", {
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")))?; config.set_str(
} "core.sshCommand",
#[cfg(not(target_os = "windows"))] &format!("ssh -i \"{}\"", path_str.replace('\\', "/")),
{ )?;
config.set_str("core.sshCommand", }
&format!("ssh -i '{}'", path_str))?; #[cfg(not(target_os = "windows"))]
} {
config.set_str("core.sshCommand", &format!("ssh -i '{}'", path_str))?;
} }
} }
Ok(()) Ok(())
} }
/// Apply this profile globally /// Apply this profile globally
pub fn apply_global(&self) -> Result<()> { pub fn apply_global(&self) -> Result<()> {
let mut config = git2::Config::open_default()?; let mut config = git2::Config::open_default()?;
config.set_str("user.name", &self.user_name)?; config.set_str("user.name", &self.user_name)?;
config.set_str("user.email", &self.user_email)?; config.set_str("user.email", &self.user_email)?;
if let Some(key) = self.signing_key() { if let Some(key) = self.signing_key() {
config.set_str("user.signingkey", key)?; config.set_str("user.signingkey", key)?;
if self.settings.auto_sign_commits { if self.settings.auto_sign_commits {
config.set_bool("commit.gpgsign", true)?; config.set_bool("commit.gpgsign", true)?;
} }
if self.settings.auto_sign_tags { if self.settings.auto_sign_tags {
config.set_bool("tag.gpgsign", true)?; config.set_bool("tag.gpgsign", true)?;
} }
} }
if let Some(ref ssh) = self.ssh { if let Some(ref ssh) = self.ssh
if let Some(ref key_path) = ssh.private_key_path { && let Some(ref key_path) = ssh.private_key_path
let path_str = key_path.display().to_string(); {
#[cfg(target_os = "windows")] let path_str = key_path.display().to_string();
{ #[cfg(target_os = "windows")]
config.set_str("core.sshCommand", {
&format!("ssh -i \"{}\"", path_str.replace('\\', "/")))?; config.set_str(
} "core.sshCommand",
#[cfg(not(target_os = "windows"))] &format!("ssh -i \"{}\"", path_str.replace('\\', "/")),
{ )?;
config.set_str("core.sshCommand", }
&format!("ssh -i '{}'", path_str))?; #[cfg(not(target_os = "windows"))]
} {
config.set_str("core.sshCommand", &format!("ssh -i '{}'", path_str))?;
} }
} }
Ok(()) Ok(())
} }
/// Compare with current git configuration /// Compare with current git configuration
pub fn compare_with_git_config(&self, repo: &git2::Repository) -> Result<ProfileComparison> { pub fn compare_with_git_config(&self, repo: &git2::Repository) -> Result<ProfileComparison> {
let config = repo.config()?; let config = repo.config()?;
let git_user_name = config.get_string("user.name").ok(); let git_user_name = config.get_string("user.name").ok();
let git_user_email = config.get_string("user.email").ok(); let git_user_email = config.get_string("user.email").ok();
let git_signing_key = config.get_string("user.signingkey").ok(); let git_signing_key = config.get_string("user.signingkey").ok();
let mut comparison = ProfileComparison { let mut comparison = ProfileComparison {
profile_name: self.name.clone(), profile_name: self.name.clone(),
matches: true, matches: true,
differences: vec![], differences: vec![],
}; };
if git_user_name.as_deref() != Some(&self.user_name) { if git_user_name.as_deref() != Some(&self.user_name) {
comparison.matches = false; comparison.matches = false;
comparison.differences.push(ConfigDifference { comparison.differences.push(ConfigDifference {
@@ -254,7 +255,7 @@ impl GitProfile {
git_value: git_user_name.unwrap_or_else(|| "<not set>".to_string()), git_value: git_user_name.unwrap_or_else(|| "<not set>".to_string()),
}); });
} }
if git_user_email.as_deref() != Some(&self.user_email) { if git_user_email.as_deref() != Some(&self.user_email) {
comparison.matches = false; comparison.matches = false;
comparison.differences.push(ConfigDifference { comparison.differences.push(ConfigDifference {
@@ -263,24 +264,24 @@ impl GitProfile {
git_value: git_user_email.unwrap_or_else(|| "<not set>".to_string()), git_value: git_user_email.unwrap_or_else(|| "<not set>".to_string()),
}); });
} }
if let Some(profile_key) = self.signing_key() { if let Some(profile_key) = self.signing_key()
if git_signing_key.as_deref() != Some(profile_key) { && git_signing_key.as_deref() != Some(profile_key)
comparison.matches = false; {
comparison.differences.push(ConfigDifference { comparison.matches = false;
key: "user.signingkey".to_string(), comparison.differences.push(ConfigDifference {
profile_value: profile_key.to_string(), key: "user.signingkey".to_string(),
git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()), profile_value: profile_key.to_string(),
}); git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()),
} });
} }
Ok(comparison) Ok(comparison)
} }
} }
/// Profile settings /// Profile settings
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProfileSettings { pub struct ProfileSettings {
/// Automatically sign commits /// Automatically sign commits
#[serde(default)] #[serde(default)]
@@ -307,19 +308,6 @@ pub struct ProfileSettings {
pub commit_template: Option<String>, pub commit_template: Option<String>,
} }
impl Default for ProfileSettings {
fn default() -> Self {
Self {
auto_sign_commits: false,
auto_sign_tags: false,
default_commit_format: None,
repo_patterns: vec![],
llm_provider: None,
commit_template: None,
}
}
}
/// SSH configuration /// SSH configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshConfig { pub struct SshConfig {
@@ -349,18 +337,18 @@ pub struct SshConfig {
impl SshConfig { impl SshConfig {
/// Validate SSH configuration /// Validate SSH configuration
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if let Some(ref path) = self.private_key_path { if let Some(ref path) = self.private_key_path
if !path.exists() { && !path.exists()
bail!("SSH private key does not exist: {:?}", path); {
} bail!("SSH private key does not exist: {:?}", path);
} }
if let Some(ref path) = self.public_key_path { if let Some(ref path) = self.public_key_path
if !path.exists() { && !path.exists()
bail!("SSH public key does not exist: {:?}", path); {
} bail!("SSH public key does not exist: {:?}", path);
} }
Ok(()) Ok(())
} }
@@ -423,10 +411,6 @@ impl GpgConfig {
/// Token configuration for services (GitHub, GitLab, etc.) /// Token configuration for services (GitHub, GitLab, etc.)
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenConfig { pub struct TokenConfig {
/// Token value (encrypted)
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
/// Token type (personal, oauth, etc.) /// Token type (personal, oauth, etc.)
#[serde(default)] #[serde(default)]
pub token_type: TokenType, pub token_type: TokenType,
@@ -446,25 +430,41 @@ pub struct TokenConfig {
/// Description /// Description
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
/// Indicates if a token is stored in keyring
#[serde(default)]
pub has_token: bool,
} }
impl TokenConfig { impl TokenConfig {
/// Create a new token config /// Create a new token config (token stored separately in keyring)
pub fn new(token: String, token_type: TokenType) -> Self { pub fn new(token_type: TokenType) -> Self {
Self { Self {
token: Some(token),
token_type, token_type,
scopes: vec![], scopes: vec![],
expires_at: None, expires_at: None,
last_used: None, last_used: None,
description: None, description: None,
has_token: true,
}
}
/// Create a new token config without token
pub fn without_token(token_type: TokenType) -> Self {
Self {
token_type,
scopes: vec![],
expires_at: None,
last_used: None,
description: None,
has_token: false,
} }
} }
/// Validate token configuration /// Validate token configuration
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.token.is_none() && self.token_type != TokenType::None { if !self.has_token && self.token_type != TokenType::None {
bail!("Token value is required for {:?}", self.token_type); bail!("Token is required for {:?}", self.token_type);
} }
Ok(()) Ok(())
} }
@@ -473,12 +473,19 @@ impl TokenConfig {
pub fn record_usage(&mut self) { pub fn record_usage(&mut self) {
self.last_used = Some(chrono::Utc::now().to_rfc3339()); self.last_used = Some(chrono::Utc::now().to_rfc3339());
} }
/// Mark that a token is stored
pub fn set_has_token(&mut self, has_token: bool) {
self.has_token = has_token;
}
} }
/// Token type /// Token type
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum TokenType { pub enum TokenType {
#[default]
None, None,
Personal, Personal,
OAuth, OAuth,
@@ -486,12 +493,6 @@ pub enum TokenType {
App, App,
} }
impl Default for TokenType {
fn default() -> Self {
Self::None
}
}
impl std::fmt::Display for TokenType { impl std::fmt::Display for TokenType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
@@ -621,9 +622,15 @@ impl GitProfileBuilder {
} }
pub fn build(self) -> Result<GitProfile> { pub fn build(self) -> Result<GitProfile> {
let name = self.name.ok_or_else(|| anyhow::anyhow!("Name is required"))?; let name = self
let user_name = self.user_name.ok_or_else(|| anyhow::anyhow!("User name is required"))?; .name
let user_email = self.user_email.ok_or_else(|| anyhow::anyhow!("User email is required"))?; .ok_or_else(|| anyhow::anyhow!("Name is required"))?;
let user_name = self
.user_name
.ok_or_else(|| anyhow::anyhow!("User name is required"))?;
let user_email = self
.user_email
.ok_or_else(|| anyhow::anyhow!("User email is required"))?;
Ok(GitProfile { Ok(GitProfile {
name, name,
@@ -669,13 +676,13 @@ mod tests {
"".to_string(), "".to_string(),
"invalid-email".to_string(), "invalid-email".to_string(),
); );
assert!(profile.validate().is_err()); assert!(profile.validate().is_err());
} }
#[test] #[test]
fn test_token_config() { fn test_token_config() {
let token = TokenConfig::new("test-token".to_string(), TokenType::Personal); let token = TokenConfig::new(TokenType::Personal);
assert!(token.validate().is_ok()); assert!(token.validate().is_ok());
} }
} }

View File

@@ -1,4 +1,5 @@
use crate::config::{CommitFormat, LlmConfig, Language}; use crate::config::manager::ConfigManager;
use crate::config::{CommitFormat, Language};
use crate::git::{CommitInfo, GitRepo}; use crate::git::{CommitInfo, GitRepo};
use crate::llm::{GeneratedCommit, LlmClient}; use crate::llm::{GeneratedCommit, LlmClient};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@@ -10,17 +11,44 @@ pub struct ContentGenerator {
impl ContentGenerator { impl ContentGenerator {
/// Create new content generator /// Create new content generator
pub async fn new(config: &LlmConfig) -> Result<Self> { pub async fn new(manager: &ConfigManager) -> Result<Self> {
let llm_client = LlmClient::from_config(config).await?; Self::new_with_think(manager, false).await
}
// Check if provider is available
if !llm_client.is_available().await { /// Create new content generator with thinking override
anyhow::bail!("LLM provider '{}' is not available", config.provider); pub async fn new_with_think(manager: &ConfigManager, think_override: bool) -> Result<Self> {
let mut thinking_enabled = if think_override {
true
} else {
manager.config().llm.thinking_enabled
};
// Validate thinking support per provider
if thinking_enabled {
let provider = manager.llm_provider();
if !Self::supports_thinking(provider) {
eprintln!(
"Warning: Provider '{}' does not support thinking mode. \
Disabling thinking for this invocation.",
provider
);
thinking_enabled = false;
}
} }
let llm_client = LlmClient::from_config_with_think(manager, thinking_enabled).await?;
if !llm_client.is_available().await {
anyhow::bail!("LLM provider '{}' is not available", manager.llm_provider());
}
Ok(Self { llm_client }) Ok(Self { llm_client })
} }
fn supports_thinking(provider: &str) -> bool {
matches!(provider, "deepseek" | "kimi" | "anthropic" | "openai")
}
/// Generate commit message from diff /// Generate commit message from diff
pub async fn generate_commit_message( pub async fn generate_commit_message(
&self, &self,
@@ -31,12 +59,15 @@ impl ContentGenerator {
// Truncate diff if too long // Truncate diff if too long
let max_diff_len = 4000; let max_diff_len = 4000;
let truncated_diff = if diff.len() > max_diff_len { let truncated_diff = if diff.len() > max_diff_len {
format!("{}\n... (truncated)", &diff[..max_diff_len]) let boundary = diff.floor_char_boundary(max_diff_len);
format!("{}\n... (truncated)", &diff[..boundary])
} else { } else {
diff.to_string() diff.to_string()
}; };
self.llm_client.generate_commit_message(&truncated_diff, format, language).await self.llm_client
.generate_commit_message(&truncated_diff, format, language)
.await
} }
/// Generate commit message from repository changes /// Generate commit message from repository changes
@@ -46,13 +77,14 @@ impl ContentGenerator {
format: CommitFormat, format: CommitFormat,
language: Language, language: Language,
) -> Result<GeneratedCommit> { ) -> Result<GeneratedCommit> {
let diff = repo.get_staged_diff() let diff = repo
.get_staged_diff_sorted()
.context("Failed to get staged diff")?; .context("Failed to get staged diff")?;
if diff.is_empty() { if diff.is_empty() {
anyhow::bail!("No staged changes to generate commit from"); anyhow::bail!("No staged changes to generate commit from");
} }
self.generate_commit_message(&diff, format, language).await self.generate_commit_message(&diff, format, language).await
} }
@@ -63,12 +95,12 @@ impl ContentGenerator {
commits: &[CommitInfo], commits: &[CommitInfo],
language: Language, language: Language,
) -> Result<String> { ) -> Result<String> {
let commit_messages: Vec<String> = commits let commit_messages: Vec<String> =
.iter() commits.iter().map(|c| c.subject().to_string()).collect();
.map(|c| c.subject().to_string())
.collect(); self.llm_client
.generate_tag_message(version, &commit_messages, language)
self.llm_client.generate_tag_message(version, &commit_messages, language).await .await
} }
/// Generate changelog entry /// Generate changelog entry
@@ -85,8 +117,10 @@ impl ContentGenerator {
(commit_type, c.subject().to_string()) (commit_type, c.subject().to_string())
}) })
.collect(); .collect();
self.llm_client.generate_changelog_entry(version, &typed_commits, language).await self.llm_client
.generate_changelog_entry(version, &typed_commits, language)
.await
} }
/// Generate changelog from repository /// Generate changelog from repository
@@ -102,8 +136,9 @@ impl ContentGenerator {
} else { } else {
repo.get_commits(50)? repo.get_commits(50)?
}; };
self.generate_changelog_entry(version, &commits, language).await self.generate_changelog_entry(version, &commits, language)
.await
} }
/// Interactive commit generation with user feedback /// Interactive commit generation with user feedback
@@ -114,49 +149,53 @@ impl ContentGenerator {
language: Language, language: Language,
) -> Result<GeneratedCommit> { ) -> Result<GeneratedCommit> {
use dialoguer::Select; use dialoguer::Select;
let diff = repo.get_staged_diff()?; let diff = repo.get_staged_diff_sorted()?;
if diff.is_empty() { if diff.is_empty() {
anyhow::bail!("No staged changes"); anyhow::bail!("No staged changes");
} }
// Show diff summary // Show diff summary
let files = repo.get_staged_files()?; let files = repo.get_staged_files()?;
println!("\nStaged files ({}):", files.len()); println!("\nStaged files ({}):", files.len());
for file in &files { for file in &files {
println!("{}", file); println!("{}", file);
} }
// Generate initial commit // Generate initial commit
println!("\nGenerating commit message..."); println!("\nGenerating commit message...");
let mut generated = self.generate_commit_message(&diff, format, language).await?; let mut generated = self
.generate_commit_message(&diff, format, language)
.await?;
loop { loop {
println!("\n{}", "".repeat(60)); println!("\n{}", "".repeat(60));
println!("Generated commit message:"); println!("Generated commit message:");
println!("{}", "".repeat(60)); println!("{}", "".repeat(60));
println!("{}", generated.to_conventional()); println!("{}", generated.to_conventional());
println!("{}", "".repeat(60)); println!("{}", "".repeat(60));
let options = vec![ let options = vec![
"✓ Accept and commit", "✓ Accept and commit",
"🔄 Regenerate", "🔄 Regenerate",
"✏️ Edit", "✏️ Edit",
"❌ Cancel", "❌ Cancel",
]; ];
let selection = Select::new() let selection = Select::new()
.with_prompt("What would you like to do?") .with_prompt("What would you like to do?")
.items(&options) .items(&options)
.default(0) .default(0)
.interact()?; .interact()?;
match selection { match selection {
0 => return Ok(generated), 0 => return Ok(generated),
1 => { 1 => {
println!("Regenerating..."); println!("Regenerating...");
generated = self.generate_commit_message(&diff, format, language).await?; generated = self
.generate_commit_message(&diff, format, language)
.await?;
} }
2 => { 2 => {
let edited = crate::utils::editor::edit_content(&generated.to_conventional())?; let edited = crate::utils::editor::edit_content(&generated.to_conventional())?;
@@ -170,7 +209,7 @@ impl ContentGenerator {
fn parse_edited_commit(&self, edited: &str, _format: CommitFormat) -> Result<GeneratedCommit> { fn parse_edited_commit(&self, edited: &str, _format: CommitFormat) -> Result<GeneratedCommit> {
let parsed = crate::git::commit::parse_commit_message(edited); let parsed = crate::git::commit::parse_commit_message(edited);
Ok(GeneratedCommit { Ok(GeneratedCommit {
commit_type: parsed.commit_type.unwrap_or_else(|| "chore".to_string()), commit_type: parsed.commit_type.unwrap_or_else(|| "chore".to_string()),
scope: parsed.scope, scope: parsed.scope,
@@ -207,11 +246,15 @@ pub mod fallback {
let has_code = files.iter().any(|f| { let has_code = files.iter().any(|f| {
f.ends_with(".rs") || f.ends_with(".py") || f.ends_with(".js") || f.ends_with(".ts") f.ends_with(".rs") || f.ends_with(".py") || f.ends_with(".js") || f.ends_with(".ts")
}); });
let has_docs = files.iter().any(|f| f.ends_with(".md") || f.contains("README")); let has_docs = files
.iter()
let has_tests = files.iter().any(|f| f.contains("test") || f.contains("spec")); .any(|f| f.ends_with(".md") || f.contains("README"));
let has_tests = files
.iter()
.any(|f| f.contains("test") || f.contains("spec"));
if has_tests { if has_tests {
"test: update tests".to_string() "test: update tests".to_string()
} else if has_docs { } else if has_docs {

View File

@@ -95,9 +95,7 @@ impl ChangelogGenerator {
ChangelogFormat::GitHubReleases => { ChangelogFormat::GitHubReleases => {
self.generate_github_releases(version, date, commits) self.generate_github_releases(version, date, commits)
} }
ChangelogFormat::Custom => { ChangelogFormat::Custom => self.generate_custom(version, date, commits),
self.generate_custom(version, date, commits)
}
} }
} }
@@ -110,13 +108,13 @@ impl ChangelogGenerator {
commits: &[CommitInfo], commits: &[CommitInfo],
) -> Result<()> { ) -> Result<()> {
let entry = self.generate(version, date, commits)?; let entry = self.generate(version, date, commits)?;
let existing = if changelog_path.exists() { let existing = if changelog_path.exists() {
fs::read_to_string(changelog_path)? fs::read_to_string(changelog_path)?
} else { } else {
String::new() String::new()
}; };
let new_content = if existing.is_empty() { let new_content = if existing.is_empty() {
format!("{}{}", CHANGELOG_HEADER, entry) format!("{}{}", CHANGELOG_HEADER, entry)
} else if existing.starts_with(CHANGELOG_HEADER) { } else if existing.starts_with(CHANGELOG_HEADER) {
@@ -124,7 +122,7 @@ impl ChangelogGenerator {
} else if existing.starts_with("# Changelog") { } else if existing.starts_with("# Changelog") {
let lines: Vec<&str> = existing.lines().collect(); let lines: Vec<&str> = existing.lines().collect();
let mut header_end = 0; let mut header_end = 0;
for (i, line) in lines.iter().enumerate() { for (i, line) in lines.iter().enumerate() {
if i == 0 && line.starts_with('#') { if i == 0 && line.starts_with('#') {
header_end = i + 1; header_end = i + 1;
@@ -134,18 +132,18 @@ impl ChangelogGenerator {
break; break;
} }
} }
let header = lines[..header_end].join("\n"); let header = lines[..header_end].join("\n");
let rest = lines[header_end..].join("\n"); let rest = lines[header_end..].join("\n");
format!("{}\n{}\n{}", header, entry, rest) format!("{}\n{}\n{}", header, entry, rest)
} else { } else {
format!("{}{}", CHANGELOG_HEADER, entry) format!("{}{}", CHANGELOG_HEADER, entry)
}; };
fs::write(changelog_path, new_content) fs::write(changelog_path, new_content)
.with_context(|| format!("Failed to write changelog: {:?}", changelog_path))?; .with_context(|| format!("Failed to write changelog: {:?}", changelog_path))?;
Ok(()) Ok(())
} }
@@ -157,10 +155,10 @@ impl ChangelogGenerator {
) -> Result<String> { ) -> Result<String> {
let date_str = date.format("%Y-%m-%d").to_string(); let date_str = date.format("%Y-%m-%d").to_string();
let mut output = format!("## [{}] - {}\n\n", version, date_str); let mut output = format!("## [{}] - {}\n\n", version, date_str);
if self.group_by_type { if self.group_by_type {
let grouped = self.group_commits(commits); let _grouped = self.group_commits(commits);
// Standard categories // Standard categories
let categories = vec![ let categories = vec![
("Added", vec!["feat"]), ("Added", vec!["feat"]),
@@ -170,7 +168,7 @@ impl ChangelogGenerator {
("Fixed", vec!["fix"]), ("Fixed", vec!["fix"]),
("Security", vec!["security"]), ("Security", vec!["security"]),
]; ];
for (title, types) in &categories { for (title, types) in &categories {
let items: Vec<&CommitInfo> = commits let items: Vec<&CommitInfo> = commits
.iter() .iter()
@@ -182,7 +180,7 @@ impl ChangelogGenerator {
} }
}) })
.collect(); .collect();
if !items.is_empty() { if !items.is_empty() {
output.push_str(&format!("### {}\n\n", title)); output.push_str(&format!("### {}\n\n", title));
for commit in items { for commit in items {
@@ -192,13 +190,13 @@ impl ChangelogGenerator {
output.push('\n'); output.push('\n');
} }
} }
// Other changes // Other changes
let categorized: Vec<String> = categories let categorized: Vec<String> = categories
.iter() .iter()
.flat_map(|(_, types)| types.iter().map(|s| s.to_string())) .flat_map(|(_, types)| types.iter().map(|s| s.to_string()))
.collect(); .collect();
let other: Vec<&CommitInfo> = commits let other: Vec<&CommitInfo> = commits
.iter() .iter()
.filter(|c| { .filter(|c| {
@@ -209,7 +207,7 @@ impl ChangelogGenerator {
} }
}) })
.collect(); .collect();
if !other.is_empty() { if !other.is_empty() {
output.push_str("### Other\n\n"); output.push_str("### Other\n\n");
for commit in other { for commit in other {
@@ -224,30 +222,30 @@ impl ChangelogGenerator {
output.push('\n'); output.push('\n');
} }
} }
Ok(output) Ok(output)
} }
fn generate_github_releases( fn generate_github_releases(
&self, &self,
version: &str, _version: &str,
_date: DateTime<Utc>, _date: DateTime<Utc>,
commits: &[CommitInfo], commits: &[CommitInfo],
) -> Result<String> { ) -> Result<String> {
let mut output = format!("## What's Changed\n\n"); let mut output = "## What's Changed\n\n".to_string();
// Group by type // Group by type
let mut features = vec![]; let mut features = vec![];
let mut fixes = vec![]; let mut fixes = vec![];
let mut docs = vec![]; let mut docs = vec![];
let mut other = vec![]; let mut other = vec![];
let mut breaking = vec![]; let mut breaking = vec![];
for commit in commits { for commit in commits {
if commit.message.contains("BREAKING CHANGE") { if commit.message.contains("BREAKING CHANGE") {
breaking.push(commit); breaking.push(commit);
} }
if let Some(ref t) = commit.commit_type() { if let Some(ref t) = commit.commit_type() {
match t.as_str() { match t.as_str() {
"feat" => features.push(commit), "feat" => features.push(commit),
@@ -259,7 +257,7 @@ impl ChangelogGenerator {
other.push(commit); other.push(commit);
} }
} }
if !breaking.is_empty() { if !breaking.is_empty() {
output.push_str("### ⚠ Breaking Changes\n\n"); output.push_str("### ⚠ Breaking Changes\n\n");
for commit in breaking { for commit in breaking {
@@ -267,7 +265,7 @@ impl ChangelogGenerator {
} }
output.push('\n'); output.push('\n');
} }
if !features.is_empty() { if !features.is_empty() {
output.push_str("### 🚀 Features\n\n"); output.push_str("### 🚀 Features\n\n");
for commit in features { for commit in features {
@@ -275,7 +273,7 @@ impl ChangelogGenerator {
} }
output.push('\n'); output.push('\n');
} }
if !fixes.is_empty() { if !fixes.is_empty() {
output.push_str("### 🐛 Bug Fixes\n\n"); output.push_str("### 🐛 Bug Fixes\n\n");
for commit in fixes { for commit in fixes {
@@ -283,7 +281,7 @@ impl ChangelogGenerator {
} }
output.push('\n'); output.push('\n');
} }
if !docs.is_empty() { if !docs.is_empty() {
output.push_str("### 📚 Documentation\n\n"); output.push_str("### 📚 Documentation\n\n");
for commit in docs { for commit in docs {
@@ -291,14 +289,14 @@ impl ChangelogGenerator {
} }
output.push('\n'); output.push('\n');
} }
if !other.is_empty() { if !other.is_empty() {
output.push_str("### Other Changes\n\n"); output.push_str("### Other Changes\n\n");
for commit in other { for commit in other {
output.push_str(&self.format_commit_github(commit)); output.push_str(&self.format_commit_github(commit));
} }
} }
Ok(output) Ok(output)
} }
@@ -312,7 +310,7 @@ impl ChangelogGenerator {
if !self.custom_categories.is_empty() { if !self.custom_categories.is_empty() {
let date_str = date.format("%Y-%m-%d").to_string(); let date_str = date.format("%Y-%m-%d").to_string();
let mut output = format!("## [{}] - {}\n\n", version, date_str); let mut output = format!("## [{}] - {}\n\n", version, date_str);
for category in &self.custom_categories { for category in &self.custom_categories {
let items: Vec<&CommitInfo> = commits let items: Vec<&CommitInfo> = commits
.iter() .iter()
@@ -324,7 +322,7 @@ impl ChangelogGenerator {
} }
}) })
.collect(); .collect();
if !items.is_empty() { if !items.is_empty() {
output.push_str(&format!("### {}\n\n", category.title)); output.push_str(&format!("### {}\n\n", category.title));
for commit in items { for commit in items {
@@ -334,7 +332,7 @@ impl ChangelogGenerator {
output.push('\n'); output.push('\n');
} }
} }
Ok(output) Ok(output)
} else { } else {
// Fall back to keep-a-changelog // Fall back to keep-a-changelog
@@ -344,30 +342,35 @@ impl ChangelogGenerator {
fn format_commit(&self, commit: &CommitInfo) -> String { fn format_commit(&self, commit: &CommitInfo) -> String {
let mut line = format!("- {}", commit.subject()); let mut line = format!("- {}", commit.subject());
if self.include_hashes { if self.include_hashes {
line.push_str(&format!(" ({})", &commit.short_id)); line.push_str(&format!(" ({})", &commit.short_id));
} }
if self.include_authors { if self.include_authors {
line.push_str(&format!(" - @{}", commit.author)); line.push_str(&format!(" - @{}", commit.author));
} }
line line
} }
fn format_commit_github(&self, commit: &CommitInfo) -> String { fn format_commit_github(&self, commit: &CommitInfo) -> String {
format!("- {} by @{} in {}\n", commit.subject(), commit.author, &commit.short_id) format!(
"- {} by @{} in {}\n",
commit.subject(),
commit.author,
&commit.short_id
)
} }
fn group_commits<'a>(&self, commits: &'a [CommitInfo]) -> HashMap<String, Vec<&'a CommitInfo>> { fn group_commits<'a>(&self, commits: &'a [CommitInfo]) -> HashMap<String, Vec<&'a CommitInfo>> {
let mut groups: HashMap<String, Vec<&'a CommitInfo>> = HashMap::new(); let mut groups: HashMap<String, Vec<&'a CommitInfo>> = HashMap::new();
for commit in commits { for commit in commits {
let commit_type = commit.commit_type().unwrap_or_else(|| "other".to_string()); let commit_type = commit.commit_type().unwrap_or_else(|| "other".to_string());
groups.entry(commit_type).or_default().push(commit); groups.entry(commit_type).or_default().push(commit);
} }
groups groups
} }
} }
@@ -380,8 +383,7 @@ impl Default for ChangelogGenerator {
/// Read existing changelog /// Read existing changelog
pub fn read_changelog(path: &Path) -> Result<String> { pub fn read_changelog(path: &Path) -> Result<String> {
fs::read_to_string(path) fs::read_to_string(path).with_context(|| format!("Failed to read changelog: {:?}", path))
.with_context(|| format!("Failed to read changelog: {:?}", path))
} }
/// Initialize new changelog file /// Initialize new changelog file
@@ -389,10 +391,10 @@ pub fn init_changelog(path: &Path) -> Result<()> {
if path.exists() { if path.exists() {
anyhow::bail!("Changelog already exists at {:?}", path); anyhow::bail!("Changelog already exists at {:?}", path);
} }
fs::write(path, CHANGELOG_HEADER) fs::write(path, CHANGELOG_HEADER)
.with_context(|| format!("Failed to create changelog: {:?}", path))?; .with_context(|| format!("Failed to create changelog: {:?}", path))?;
Ok(()) Ok(())
} }
@@ -403,7 +405,7 @@ pub fn generate_from_history(
to_ref: Option<&str>, to_ref: Option<&str>,
) -> Result<Vec<CommitInfo>> { ) -> Result<Vec<CommitInfo>> {
let to_ref = to_ref.unwrap_or("HEAD"); let to_ref = to_ref.unwrap_or("HEAD");
if let Some(from) = from_tag { if let Some(from) = from_tag {
repo.get_commits_between(from, to_ref) repo.get_commits_between(from, to_ref)
} else { } else {
@@ -413,11 +415,7 @@ pub fn generate_from_history(
} }
/// Update version links in changelog /// Update version links in changelog
pub fn update_version_links( pub fn update_version_links(changelog: &str, version: &str, compare_url: &str) -> String {
changelog: &str,
version: &str,
compare_url: &str,
) -> String {
// Add version link at the end of changelog // Add version link at the end of changelog
format!("{}\n[{}]: {}\n", changelog, version, compare_url) format!("{}\n[{}]: {}\n", changelog, version, compare_url)
} }
@@ -425,30 +423,29 @@ pub fn update_version_links(
/// Parse changelog to extract versions /// Parse changelog to extract versions
pub fn parse_versions(changelog: &str) -> Vec<(String, String)> { pub fn parse_versions(changelog: &str) -> Vec<(String, String)> {
let mut versions = vec![]; let mut versions = vec![];
for line in changelog.lines() { for line in changelog.lines() {
if line.starts_with("## [") { if line.starts_with("## [")
if let Some(start) = line.find('[') { && let Some(start) = line.find('[')
if let Some(end) = line.find(']') { && let Some(end) = line.find(']')
let version = &line[start + 1..end]; {
if version != "Unreleased" { let version = &line[start + 1..end];
if let Some(date_start) = line.find(" - ") { if version != "Unreleased"
let date = &line[date_start + 3..].trim(); && let Some(date_start) = line.find(" - ")
versions.push((version.to_string(), date.to_string())); {
} let date = &line[date_start + 3..].trim();
} versions.push((version.to_string(), date.to_string()));
}
} }
} }
} }
versions versions
} }
/// Get unreleased changes /// Get unreleased changes
pub fn get_unreleased_changes(repo: &GitRepo) -> Result<Vec<CommitInfo>> { pub fn get_unreleased_changes(repo: &GitRepo) -> Result<Vec<CommitInfo>> {
let tags = repo.get_tags()?; let tags = repo.get_tags()?;
if let Some(latest_tag) = tags.first() { if let Some(latest_tag) = tags.first() {
repo.get_commits_between(&latest_tag.name, "HEAD") repo.get_commits_between(&latest_tag.name, "HEAD")
} else { } else {

View File

@@ -1,5 +1,5 @@
use super::GitRepo; use super::GitRepo;
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use chrono::Local; use chrono::Local;
/// Commit builder for creating commits /// Commit builder for creating commits
@@ -119,10 +119,14 @@ impl CommitBuilder {
return Ok(msg.clone()); return Ok(msg.clone());
} }
let commit_type = self.commit_type.as_ref() let commit_type = self
.commit_type
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Commit type is required"))?; .ok_or_else(|| anyhow::anyhow!("Commit type is required"))?;
let description = self.description.as_ref() let description = self
.description
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Description is required"))?; .ok_or_else(|| anyhow::anyhow!("Description is required"))?;
let message = match self.format { let message = match self.format {
@@ -166,45 +170,46 @@ impl CommitBuilder {
fn amend_commit(&self, repo: &GitRepo, message: &str) -> Result<()> { fn amend_commit(&self, repo: &GitRepo, message: &str) -> Result<()> {
use std::process::Command; use std::process::Command;
let mut args = vec!["commit", "--amend"]; let mut args = vec!["commit", "--amend"];
if self.no_verify { if self.no_verify {
args.push("--no-verify"); args.push("--no-verify");
} }
args.push("-m"); args.push("-m");
args.push(message); args.push(message);
if self.sign { if self.sign {
args.push("-S"); args.push("-S");
} }
let output = Command::new("git") let output = Command::new("git")
.args(&args) .args(&args)
.current_dir(repo.path()) .current_dir(repo.path())
.output()?; .output()?;
if !output.status.success() { if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout); let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() { let error_msg = if stderr.is_empty() {
if stdout.is_empty() { if stdout.is_empty() {
"GPG signing failed. Please check:\n\ "GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\ 1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\ 2. GPG agent is running\n\
3. You can sign commits manually (try: git commit --amend -S)".to_string() 3. You can sign commits manually (try: git commit --amend -S)"
.to_string()
} else { } else {
stdout.to_string() stdout.to_string()
} }
} else { } else {
stderr.to_string() stderr.to_string()
}; };
bail!("Failed to amend commit: {}", error_msg); bail!("Failed to amend commit: {}", error_msg);
} }
Ok(()) Ok(())
} }
} }
@@ -219,7 +224,7 @@ impl Default for CommitBuilder {
pub fn create_date_commit_message(prefix: Option<&str>) -> String { pub fn create_date_commit_message(prefix: Option<&str>) -> String {
let now = Local::now(); let now = Local::now();
let date_str = now.format("%Y-%m-%d").to_string(); let date_str = now.format("%Y-%m-%d").to_string();
match prefix { match prefix {
Some(p) => format!("{}: {}", p, date_str), Some(p) => format!("{}: {}", p, date_str),
None => format!("chore: update {}", date_str), None => format!("chore: update {}", date_str),
@@ -229,58 +234,65 @@ pub fn create_date_commit_message(prefix: Option<&str>) -> String {
/// Commit type suggestions based on diff /// Commit type suggestions based on diff
pub fn suggest_commit_type(diff: &str) -> Vec<&'static str> { pub fn suggest_commit_type(diff: &str) -> Vec<&'static str> {
let mut suggestions = vec![]; let mut suggestions = vec![];
// Check for test files // Check for test files
if diff.contains("test") || diff.contains("spec") || diff.contains("__tests__") { if diff.contains("test") || diff.contains("spec") || diff.contains("__tests__") {
suggestions.push("test"); suggestions.push("test");
} }
// Check for documentation // Check for documentation
if diff.contains("README") || diff.contains(".md") || diff.contains("docs/") { if diff.contains("README") || diff.contains(".md") || diff.contains("docs/") {
suggestions.push("docs"); suggestions.push("docs");
} }
// Check for configuration files // Check for configuration files
if diff.contains("config") || diff.contains(".json") || diff.contains(".yaml") || diff.contains(".toml") { if diff.contains("config")
|| diff.contains(".json")
|| diff.contains(".yaml")
|| diff.contains(".toml")
{
suggestions.push("chore"); suggestions.push("chore");
} }
// Check for dependencies // Check for dependencies
if diff.contains("Cargo.toml") || diff.contains("package.json") || diff.contains("requirements.txt") { if diff.contains("Cargo.toml")
|| diff.contains("package.json")
|| diff.contains("requirements.txt")
{
suggestions.push("build"); suggestions.push("build");
} }
// Check for CI // Check for CI
if diff.contains(".github/") || diff.contains(".gitlab-") || diff.contains("Jenkinsfile") { if diff.contains(".github/") || diff.contains(".gitlab-") || diff.contains("Jenkinsfile") {
suggestions.push("ci"); suggestions.push("ci");
} }
// Default suggestions // Default suggestions
if suggestions.is_empty() { if suggestions.is_empty() {
suggestions.extend(&["feat", "fix", "refactor"]); suggestions.extend(&["feat", "fix", "refactor"]);
} }
suggestions suggestions
} }
/// Parse existing commit message /// Parse existing commit message
pub fn parse_commit_message(message: &str) -> ParsedCommit { pub fn parse_commit_message(message: &str) -> ParsedCommit {
let lines: Vec<&str> = message.lines().collect(); let lines: Vec<&str> = message.lines().collect();
if lines.is_empty() { if lines.is_empty() {
return ParsedCommit::default(); return ParsedCommit::default();
} }
let first_line = lines[0]; let first_line = lines[0];
// Try to parse as conventional commit // Try to parse as conventional commit
if let Some(colon_pos) = first_line.find(':') { if let Some(colon_pos) = first_line.find(':') {
let type_part = &first_line[..colon_pos]; let type_part = &first_line[..colon_pos];
let description = first_line[colon_pos + 1..].trim(); let description = first_line[colon_pos + 1..].trim();
let breaking = type_part.ends_with('!'); let breaking = type_part.ends_with('!');
let type_part = type_part.trim_end_matches('!'); let type_part = type_part.trim_end_matches('!');
let (commit_type, scope) = if let Some(open) = type_part.find('(') { let (commit_type, scope) = if let Some(open) = type_part.find('(') {
if let Some(close) = type_part.find(')') { if let Some(close) = type_part.find(')') {
let t = &type_part[..open]; let t = &type_part[..open];
@@ -292,42 +304,51 @@ pub fn parse_commit_message(message: &str) -> ParsedCommit {
} else { } else {
(Some(type_part.to_string()), None) (Some(type_part.to_string()), None)
}; };
// Extract body and footer // Extract body and footer
let mut body_lines = vec![]; let mut body_lines = vec![];
let mut footer_lines = vec![]; let mut footer_lines = vec![];
let mut in_footer = false; let mut in_footer = false;
for line in &lines[1..] { for line in &lines[1..] {
if line.trim().is_empty() { if line.trim().is_empty() {
continue; continue;
} }
if line.starts_with("BREAKING CHANGE:") || if line.starts_with("BREAKING CHANGE:")
line.starts_with("Closes") || || line.starts_with("Closes")
line.starts_with("Fixes") || || line.starts_with("Fixes")
line.starts_with("Refs") || || line.starts_with("Refs")
line.starts_with("Co-authored-by:") { || line.starts_with("Co-authored-by:")
{
in_footer = true; in_footer = true;
} }
if in_footer { if in_footer {
footer_lines.push(line.to_string()); footer_lines.push(line.to_string());
} else { } else {
body_lines.push(line.to_string()); body_lines.push(line.to_string());
} }
} }
return ParsedCommit { return ParsedCommit {
commit_type, commit_type,
scope, scope,
description: Some(description.to_string()), description: Some(description.to_string()),
body: if body_lines.is_empty() { None } else { Some(body_lines.join("\n")) }, body: if body_lines.is_empty() {
footer: if footer_lines.is_empty() { None } else { Some(footer_lines.join("\n")) }, None
} else {
Some(body_lines.join("\n"))
},
footer: if footer_lines.is_empty() {
None
} else {
Some(footer_lines.join("\n"))
},
breaking, breaking,
}; };
} }
// Non-conventional commit // Non-conventional commit
ParsedCommit { ParsedCommit {
description: Some(first_line.to_string()), description: Some(first_line.to_string()),
@@ -351,7 +372,7 @@ impl ParsedCommit {
pub fn to_message(&self, format: crate::config::CommitFormat) -> String { pub fn to_message(&self, format: crate::config::CommitFormat) -> String {
let commit_type = self.commit_type.as_deref().unwrap_or("chore"); let commit_type = self.commit_type.as_deref().unwrap_or("chore");
let description = self.description.as_deref().unwrap_or("update"); let description = self.description.as_deref().unwrap_or("update");
match format { match format {
crate::config::CommitFormat::Conventional => { crate::config::CommitFormat::Conventional => {
crate::utils::formatter::format_conventional_commit( crate::utils::formatter::format_conventional_commit(

View File

@@ -1,54 +1,49 @@
use anyhow::{bail, Context, Result}; use anyhow::{Context, Result, bail};
use git2::{Repository, Signature, StatusOptions, Config, Oid, ObjectType}; use git2::{Config, ObjectType, Oid, Repository, Signature, StatusOptions};
use std::path::{Path, PathBuf, Component};
use std::collections::HashMap; use std::collections::HashMap;
use tempfile; use std::path::{Component, Path, PathBuf};
pub mod changelog; pub mod changelog;
pub mod commit; pub mod commit;
pub mod tag; pub mod tag;
#[cfg(target_os = "windows")]
use std::os::windows::ffi::OsStringExt;
fn normalize_path_for_git2(path: &Path) -> PathBuf { fn normalize_path_for_git2(path: &Path) -> PathBuf {
let mut normalized = path.to_path_buf(); let mut normalized = path.to_path_buf();
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
{ {
let path_str = path.to_string_lossy(); let path_str = path.to_string_lossy();
if path_str.starts_with(r"\\?\") { if path_str.starts_with(r"\\?\")
if let Some(stripped) = path_str.strip_prefix(r"\\?\") { && let Some(stripped) = path_str.strip_prefix(r"\\?\")
normalized = PathBuf::from(stripped); {
} normalized = PathBuf::from(stripped);
} }
if path_str.starts_with(r"\\?\UNC\") { if path_str.starts_with(r"\\?\UNC\")
if let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") { && let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\")
normalized = PathBuf::from(format!(r"\\{}", stripped)); {
} normalized = PathBuf::from(format!(r"\\{}", stripped));
} }
} }
normalized normalized
} }
fn get_absolute_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> { fn get_absolute_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
let path = path.as_ref(); let path = path.as_ref();
if path.is_absolute() { if path.is_absolute() {
return Ok(normalize_path_for_git2(path)); return Ok(normalize_path_for_git2(path));
} }
let current_dir = std::env::current_dir() let current_dir = std::env::current_dir().with_context(|| "Failed to get current directory")?;
.with_context(|| "Failed to get current directory")?;
let absolute = current_dir.join(path); let absolute = current_dir.join(path);
Ok(normalize_path_for_git2(&absolute)) Ok(normalize_path_for_git2(&absolute))
} }
fn resolve_path_without_canonicalize(path: &Path) -> PathBuf { fn resolve_path_without_canonicalize(path: &Path) -> PathBuf {
let mut components = Vec::new(); let mut components = Vec::new();
for component in path.components() { for component in path.components() {
match component { match component {
Component::ParentDir => { Component::ParentDir => {
@@ -62,62 +57,62 @@ fn resolve_path_without_canonicalize(path: &Path) -> PathBuf {
_ => components.push(component), _ => components.push(component),
} }
} }
let mut result = PathBuf::new(); let mut result = PathBuf::new();
for component in components { for component in components {
result.push(component.as_os_str()); result.push(component.as_os_str());
} }
normalize_path_for_git2(&result) normalize_path_for_git2(&result)
} }
fn try_open_repo_with_git2(path: &Path) -> Result<Repository> { fn try_open_repo_with_git2(path: &Path) -> Result<Repository> {
let normalized = normalize_path_for_git2(path); let normalized = normalize_path_for_git2(path);
let discover_opts = git2::RepositoryOpenFlags::empty(); let discover_opts = git2::RepositoryOpenFlags::empty();
let ceiling_dirs: [&str; 0] = []; let ceiling_dirs: [&str; 0] = [];
let repo = Repository::open_ext(&normalized, discover_opts, &ceiling_dirs) let repo = Repository::open_ext(&normalized, discover_opts, ceiling_dirs)
.or_else(|_| Repository::discover(&normalized)) .or_else(|_| Repository::discover(&normalized))
.or_else(|_| Repository::open(&normalized)); .or_else(|_| Repository::open(&normalized));
repo.map_err(|e| anyhow::anyhow!("git2 failed: {}", e)) repo.map_err(|e| anyhow::anyhow!("git2 failed: {}", e))
} }
fn try_open_repo_with_git_cli(path: &Path) -> Result<Repository> { fn try_open_repo_with_git_cli(path: &Path) -> Result<Repository> {
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["rev-parse", "--show-toplevel"]) .args(["rev-parse", "--show-toplevel"])
.current_dir(path) .current_dir(path)
.output() .output()
.context("Failed to execute git command")?; .context("Failed to execute git command")?;
if !output.status.success() { if !output.status.success() {
bail!("git CLI failed to find repository"); bail!("git CLI failed to find repository");
} }
let stdout = String::from_utf8_lossy(&output.stdout); let stdout = String::from_utf8_lossy(&output.stdout);
let git_root = stdout.trim(); let git_root = stdout.trim();
if git_root.is_empty() { if git_root.is_empty() {
bail!("git CLI returned empty path"); bail!("git CLI returned empty path");
} }
let git_root_path = PathBuf::from(git_root); let git_root_path = PathBuf::from(git_root);
let normalized = normalize_path_for_git2(&git_root_path); let normalized = normalize_path_for_git2(&git_root_path);
Repository::open(&normalized) Repository::open(&normalized)
.with_context(|| format!("Failed to open repo from git CLI path: {:?}", normalized)) .with_context(|| format!("Failed to open repo from git CLI path: {:?}", normalized))
} }
fn diagnose_repo_issue(path: &Path) -> String { fn diagnose_repo_issue(path: &Path) -> String {
let mut issues = Vec::new(); let mut issues = Vec::new();
if !path.exists() { if !path.exists() {
issues.push(format!("Path does not exist: {:?}", path)); issues.push(format!("Path does not exist: {:?}", path));
} else if !path.is_dir() { } else if !path.is_dir() {
issues.push(format!("Path is not a directory: {:?}", path)); issues.push(format!("Path is not a directory: {:?}", path));
} }
let git_dir = path.join(".git"); let git_dir = path.join(".git");
if git_dir.exists() { if git_dir.exists() {
if git_dir.is_dir() { if git_dir.is_dir() {
@@ -133,7 +128,7 @@ fn diagnose_repo_issue(path: &Path) -> String {
} }
} else { } else {
issues.push("No .git found in current directory".to_string()); issues.push("No .git found in current directory".to_string());
let mut current = path; let mut current = path;
let mut depth = 0; let mut depth = 0;
while let Some(parent) = current.parent() { while let Some(parent) = current.parent() {
@@ -149,7 +144,7 @@ fn diagnose_repo_issue(path: &Path) -> String {
current = parent; current = parent;
} }
} }
#[cfg(target_os = "windows")] #[cfg(target_os = "windows")]
{ {
let path_str = path.to_string_lossy(); let path_str = path.to_string_lossy();
@@ -160,11 +155,11 @@ fn diagnose_repo_issue(path: &Path) -> String {
issues.push("WARNING: Path has mixed path separators".to_string()); issues.push("WARNING: Path has mixed path separators".to_string());
} }
} }
if let Ok(current_dir) = std::env::current_dir() { if let Ok(current_dir) = std::env::current_dir() {
issues.push(format!("Current working directory: {:?}", current_dir)); issues.push(format!("Current working directory: {:?}", current_dir));
} }
issues.join("\n ") issues.join("\n ")
} }
@@ -177,17 +172,15 @@ pub struct GitRepo {
impl GitRepo { impl GitRepo {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref(); let path = path.as_ref();
let absolute_path = get_absolute_path(path)?; let absolute_path = get_absolute_path(path)?;
let resolved_path = resolve_path_without_canonicalize(&absolute_path); let resolved_path = resolve_path_without_canonicalize(&absolute_path);
let repo = try_open_repo_with_git2(&resolved_path) let repo = try_open_repo_with_git2(&resolved_path).or_else(|git2_err| {
.or_else(|git2_err| { try_open_repo_with_git_cli(&resolved_path).map_err(|cli_err| {
try_open_repo_with_git_cli(&resolved_path) let diagnosis = diagnose_repo_issue(&resolved_path);
.map_err(|cli_err| { anyhow::anyhow!(
let diagnosis = diagnose_repo_issue(&resolved_path); "Failed to open git repository:\n\
anyhow::anyhow!(
"Failed to open git repository:\n\
\n\ \n\
=== git2 Error ===\n {}\n\ === git2 Error ===\n {}\n\
\n\ \n\
@@ -200,17 +193,20 @@ impl GitRepo {
2. Run: git status (to verify git works)\n\ 2. Run: git status (to verify git works)\n\
3. Run: git config --global --add safe.directory \"*\"\n\ 3. Run: git config --global --add safe.directory \"*\"\n\
4. Check file permissions", 4. Check file permissions",
git2_err, cli_err, diagnosis git2_err,
) cli_err,
}) diagnosis
})?; )
})
let repo_path = repo.workdir() })?;
let repo_path = repo
.workdir()
.map(|p| p.to_path_buf()) .map(|p| p.to_path_buf())
.unwrap_or_else(|| resolved_path.clone()); .unwrap_or_else(|| resolved_path.clone());
let config = repo.config().ok(); let config = repo.config().ok();
Ok(Self { Ok(Self {
repo, repo,
path: normalize_path_for_git2(&repo_path), path: normalize_path_for_git2(&repo_path),
@@ -251,7 +247,11 @@ impl GitRepo {
pub fn get_user_name(&self) -> Result<String> { pub fn get_user_name(&self) -> Result<String> {
self.get_config("user.name")? self.get_config("user.name")?
.or_else(|| std::env::var("GIT_AUTHOR_NAME").ok()) .or_else(|| std::env::var("GIT_AUTHOR_NAME").ok())
.ok_or_else(|| anyhow::anyhow!("User name not configured. Set it with: git config user.name \"Your Name\"")) .ok_or_else(|| {
anyhow::anyhow!(
"User name not configured. Set it with: git config user.name \"Your Name\""
)
})
} }
/// Get the configured user email /// Get the configured user email
@@ -263,7 +263,8 @@ impl GitRepo {
/// Get the configured GPG signing key /// Get the configured GPG signing key
pub fn get_signing_key(&self) -> Result<Option<String>> { pub fn get_signing_key(&self) -> Result<Option<String>> {
Ok(self.get_config("user.signingkey")? Ok(self
.get_config("user.signingkey")?
.or_else(|| std::env::var("GIT_SIGNING_KEY").ok())) .or_else(|| std::env::var("GIT_SIGNING_KEY").ok()))
} }
@@ -290,13 +291,9 @@ impl GitRepo {
if let Some(program) = self.get_config("gpg.program")? { if let Some(program) = self.get_config("gpg.program")? {
return Ok(program); return Ok(program);
} }
let default_gpg = if cfg!(windows) { let default_gpg = if cfg!(windows) { "gpg.exe" } else { "gpg" };
"gpg.exe"
} else {
"gpg"
};
Ok(default_gpg.to_string()) Ok(default_gpg.to_string())
} }
@@ -304,10 +301,13 @@ impl GitRepo {
pub fn create_signature(&self) -> Result<Signature<'_>> { pub fn create_signature(&self) -> Result<Signature<'_>> {
let name = self.get_user_name()?; let name = self.get_user_name()?;
let email = self.get_user_email()?; let email = self.get_user_email()?;
let time = git2::Time::new(std::time::SystemTime::now() let time = git2::Time::new(
.duration_since(std::time::UNIX_EPOCH) std::time::SystemTime::now()
.unwrap() .duration_since(std::time::UNIX_EPOCH)
.as_secs() as i64, 0); .unwrap()
.as_secs() as i64,
0,
);
Signature::new(&name, &email, &time).map_err(Into::into) Signature::new(&name, &email, &time).map_err(Into::into)
} }
@@ -332,7 +332,7 @@ impl GitRepo {
pub fn get_staged_diff(&self) -> Result<String> { pub fn get_staged_diff(&self) -> Result<String> {
// Use git CLI to get staged diff for better compatibility // Use git CLI to get staged diff for better compatibility
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["diff", "--cached"]) .args(["diff", "--cached"])
.current_dir(&self.path) .current_dir(&self.path)
.output() .output()
.with_context(|| "Failed to get staged diff with git command")?; .with_context(|| "Failed to get staged diff with git command")?;
@@ -346,6 +346,128 @@ impl GitRepo {
Ok(diff_text) Ok(diff_text)
} }
/// Get staged diff with files sorted by importance
/// Important files (source code) come first, then config files like Cargo.toml,
/// then lock files like Cargo.lock
pub fn get_staged_diff_sorted(&self) -> Result<String> {
let diff = self.get_staged_diff()?;
if diff.is_empty() {
return Ok(diff);
}
let mut file_diffs = Vec::new();
let mut current_file_diff = String::new();
let mut current_file = String::new();
for line in diff.lines() {
if line.starts_with("diff --git") {
// Save previous file diff if any
if !current_file_diff.is_empty() && !current_file.is_empty() {
file_diffs.push((current_file.clone(), current_file_diff.clone()));
}
current_file = extract_file_from_diff_line(line);
current_file_diff = format!("{}\n", line);
} else {
current_file_diff.push_str(line);
current_file_diff.push('\n');
}
}
// Add the last file diff
if !current_file_diff.is_empty() && !current_file.is_empty() {
file_diffs.push((current_file, current_file_diff));
}
// Sort by file importance
file_diffs.sort_by(|a, b| {
let score_a = file_importance_score(&a.0);
let score_b = file_importance_score(&b.0);
score_b.cmp(&score_a) // Descending order
});
// Combine sorted diffs
let sorted_diff: String = file_diffs.into_iter().map(|(_, diff)| diff).collect();
Ok(sorted_diff)
}
}
/// Extract filename from diff --git line
fn extract_file_from_diff_line(line: &str) -> String {
// Format: "diff --git a/path/to/file b/path/to/file"
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 {
// Return the second path (after b/)
if let Some(path) = parts[2].strip_prefix("b/") {
return path.to_string();
}
// Fallback to first path (after a/)
if let Some(path) = parts[1].strip_prefix("a/") {
return path.to_string();
}
}
line.to_string()
}
/// Calculate file importance score
/// Higher score = more important
fn file_importance_score(filename: &str) -> i32 {
// Priority list for important file types
let important_extensions = [
".rs", ".py", ".js", ".ts", ".tsx", ".jsx", ".go", ".java", ".cpp", ".c", ".rust", ".vue",
".svelte", ".html", ".css", ".scss", ".sass", ".less",
];
// Config files that are important but less than source code
let config_files = [
"Cargo.toml",
"package.json",
"go.mod",
"go.sum",
"pom.xml",
"Makefile",
"CMakeLists.txt",
"build.gradle",
"gradle.properties",
];
// Lock files - lowest priority
let lock_files = [
"Cargo.lock",
"package-lock.json",
"yarn.lock",
"pnpm-lock.yaml",
"Gemfile.lock",
"composer.lock",
];
// Check lock files first (lowest priority)
for lock in lock_files.iter() {
if filename.ends_with(lock) {
return 1;
}
}
// Check config files (medium priority)
for config in config_files.iter() {
if filename.ends_with(config) {
return 2;
}
}
// Check important source files (highest priority)
for ext in important_extensions.iter() {
if filename.ends_with(ext) {
return 3;
}
}
// Default priority for other files
2
}
impl GitRepo {
/// Get unstaged diff /// Get unstaged diff
pub fn get_unstaged_diff(&self) -> Result<String> { pub fn get_unstaged_diff(&self) -> Result<String> {
let diff = self.repo.diff_index_to_workdir(None, None)?; let diff = self.repo.diff_index_to_workdir(None, None)?;
@@ -390,18 +512,21 @@ impl GitRepo {
/// Get list of staged files /// Get list of staged files
pub fn get_staged_files(&self) -> Result<Vec<String>> { pub fn get_staged_files(&self) -> Result<Vec<String>> {
let statuses = self.repo.statuses(Some( let statuses = self
StatusOptions::new() .repo
.include_untracked(false), .statuses(Some(StatusOptions::new().include_untracked(false)))?;
))?;
let mut files = vec![]; let mut files = vec![];
for entry in statuses.iter() { for entry in statuses.iter() {
let status = entry.status(); let status = entry.status();
if status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange() { if (status.is_index_new()
if let Some(path) = entry.path() { || status.is_index_modified()
files.push(path.to_string()); || status.is_index_deleted()
} || status.is_index_renamed()
|| status.is_index_typechange())
&& let Some(path) = entry.path()
{
files.push(path.to_string());
} }
} }
@@ -431,7 +556,7 @@ impl GitRepo {
pub fn stage_all(&self) -> Result<()> { pub fn stage_all(&self) -> Result<()> {
// Use git command for reliable staging (handles all edge cases) // Use git command for reliable staging (handles all edge cases)
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["add", "-A"]) .args(["add", "-A"])
.current_dir(&self.path) .current_dir(&self.path)
.output() .output()
.with_context(|| "Failed to stage changes with git command")?; .with_context(|| "Failed to stage changes with git command")?;
@@ -515,27 +640,28 @@ impl GitRepo {
std::fs::write(temp_file.path(), message)?; std::fs::write(temp_file.path(), message)?;
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["commit", "-S", "-F", temp_file.path().to_str().unwrap()]) .args(["commit", "-S", "-F", temp_file.path().to_str().unwrap()])
.current_dir(&self.path) .current_dir(&self.path)
.output()?; .output()?;
if !output.status.success() { if !output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout); let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let error_msg = if stderr.is_empty() { let error_msg = if stderr.is_empty() {
if stdout.is_empty() { if stdout.is_empty() {
"GPG signing failed. Please check:\n\ "GPG signing failed. Please check:\n\
1. GPG signing key is configured (git config --get user.signingkey)\n\ 1. GPG signing key is configured (git config --get user.signingkey)\n\
2. GPG agent is running\n\ 2. GPG agent is running\n\
3. You can sign commits manually (try: git commit -S -m 'test')".to_string() 3. You can sign commits manually (try: git commit -S -m 'test')"
.to_string()
} else { } else {
stdout.to_string() stdout.to_string()
} }
} else { } else {
stderr.to_string() stderr.to_string()
}; };
bail!("Failed to create signed commit: {}", error_msg); bail!("Failed to create signed commit: {}", error_msg);
} }
@@ -548,7 +674,8 @@ impl GitRepo {
let head = self.repo.head()?; let head = self.repo.head()?;
if head.is_branch() { if head.is_branch() {
let name = head.shorthand() let name = head
.shorthand()
.ok_or_else(|| anyhow::anyhow!("Invalid branch name"))?; .ok_or_else(|| anyhow::anyhow!("Invalid branch name"))?;
Ok(name.to_string()) Ok(name.to_string())
} else { } else {
@@ -559,7 +686,8 @@ impl GitRepo {
/// Get current commit hash (short) /// Get current commit hash (short)
pub fn current_commit_short(&self) -> Result<String> { pub fn current_commit_short(&self) -> Result<String> {
let head = self.repo.head()?; let head = self.repo.head()?;
let oid = head.target() let oid = head
.target()
.ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?; .ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?;
Ok(oid.to_string()[..8].to_string()) Ok(oid.to_string()[..8].to_string())
} }
@@ -567,7 +695,8 @@ impl GitRepo {
/// Get current commit hash (full) /// Get current commit hash (full)
pub fn current_commit(&self) -> Result<String> { pub fn current_commit(&self) -> Result<String> {
let head = self.repo.head()?; let head = self.repo.head()?;
let oid = head.target() let oid = head
.target()
.ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?; .ok_or_else(|| anyhow::anyhow!("No target for HEAD"))?;
Ok(oid.to_string()) Ok(oid.to_string())
} }
@@ -642,12 +771,16 @@ impl GitRepo {
name: name.to_string(), name: name.to_string(),
target: oid.to_string(), target: oid.to_string(),
message: commit.message().unwrap_or("").to_string(), message: commit.message().unwrap_or("").to_string(),
time: commit.time().seconds(),
}); });
} }
true true
})?; })?;
// Sort tags by time (newest first)
tags.sort_by(|a, b| b.time.cmp(&a.time));
Ok(tags) Ok(tags)
} }
@@ -662,13 +795,7 @@ impl GitRepo {
if sign { if sign {
self.create_signed_tag_with_git2(name, msg, &sig, target.id())?; self.create_signed_tag_with_git2(name, msg, &sig, target.id())?;
} else { } else {
self.repo.tag( self.repo.tag(name, target.as_object(), &sig, msg, false)?;
name,
target.as_object(),
&sig,
msg,
false,
)?;
} }
} else { } else {
self.repo.tag( self.repo.tag(
@@ -684,9 +811,15 @@ impl GitRepo {
} }
/// Create signed tag using git CLI /// Create signed tag using git CLI
fn create_signed_tag_with_git2(&self, name: &str, message: &str, _signature: &Signature, _target_id: Oid) -> Result<()> { fn create_signed_tag_with_git2(
&self,
name: &str,
message: &str,
_signature: &Signature,
_target_id: Oid,
) -> Result<()> {
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["tag", "-s", name, "-m", message]) .args(["tag", "-s", name, "-m", message])
.current_dir(&self.path) .current_dir(&self.path)
.output()?; .output()?;
@@ -699,7 +832,12 @@ impl GitRepo {
} }
/// Create GPG signature for arbitrary content /// Create GPG signature for arbitrary content
fn create_gpg_signature_for_content(&self, _content: &str, _gpg_program: &str, _signing_key: &str) -> Result<String> { fn create_gpg_signature_for_content(
&self,
_content: &str,
_gpg_program: &str,
_signing_key: &str,
) -> Result<String> {
Ok(String::new()) Ok(String::new())
} }
@@ -712,7 +850,7 @@ impl GitRepo {
/// Push to remote /// Push to remote
pub fn push(&self, remote: &str, refspec: &str) -> Result<()> { pub fn push(&self, remote: &str, refspec: &str) -> Result<()> {
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["push", remote, refspec]) .args(["push", remote, refspec])
.current_dir(&self.path) .current_dir(&self.path)
.output()?; .output()?;
@@ -727,7 +865,8 @@ impl GitRepo {
/// Get remote URL /// Get remote URL
pub fn get_remote_url(&self, remote: &str) -> Result<String> { pub fn get_remote_url(&self, remote: &str) -> Result<String> {
let remote_obj = self.repo.find_remote(remote)?; let remote_obj = self.repo.find_remote(remote)?;
let url = remote_obj.url() let url = remote_obj
.url()
.ok_or_else(|| anyhow::anyhow!("Remote has no URL"))?; .ok_or_else(|| anyhow::anyhow!("Remote has no URL"))?;
Ok(url.to_string()) Ok(url.to_string())
} }
@@ -741,7 +880,7 @@ impl GitRepo {
pub fn status_summary(&self) -> Result<StatusSummary> { pub fn status_summary(&self) -> Result<StatusSummary> {
// Use git CLI for more reliable status detection // Use git CLI for more reliable status detection
let output = std::process::Command::new("git") let output = std::process::Command::new("git")
.args(&["status", "--porcelain"]) .args(["status", "--porcelain"])
.current_dir(&self.path) .current_dir(&self.path)
.output() .output()
.with_context(|| "Failed to get status with git command")?; .with_context(|| "Failed to get status with git command")?;
@@ -778,9 +917,10 @@ impl GitRepo {
} }
// Conflicted files (both columns are U or DD, AA, etc.) // Conflicted files (both columns are U or DD, AA, etc.)
if (index_status == 'U' || worktree_status == 'U') || if (index_status == 'U' || worktree_status == 'U')
(index_status == 'A' && worktree_status == 'A') || || (index_status == 'A' && worktree_status == 'A')
(index_status == 'D' && worktree_status == 'D') { || (index_status == 'D' && worktree_status == 'D')
{
conflicted += 1; conflicted += 1;
} }
} }
@@ -832,6 +972,7 @@ pub struct TagInfo {
pub name: String, pub name: String,
pub target: String, pub target: String,
pub message: String, pub message: String,
pub time: i64,
} }
/// Repository status summary /// Repository status summary
@@ -870,52 +1011,51 @@ impl StatusSummary {
pub fn find_repo<P: AsRef<Path>>(start_path: P) -> Result<GitRepo> { pub fn find_repo<P: AsRef<Path>>(start_path: P) -> Result<GitRepo> {
let start_path = start_path.as_ref(); let start_path = start_path.as_ref();
let absolute_start = get_absolute_path(start_path)?; let absolute_start = get_absolute_path(start_path)?;
let resolved_start = resolve_path_without_canonicalize(&absolute_start); let resolved_start = resolve_path_without_canonicalize(&absolute_start);
if let Ok(repo) = GitRepo::open(&resolved_start) { if let Ok(repo) = GitRepo::open(&resolved_start) {
return Ok(repo); return Ok(repo);
} }
let mut current = resolved_start.as_path(); let mut current = resolved_start.as_path();
let mut attempted_paths = vec![current.to_string_lossy().to_string()]; let mut attempted_paths = vec![current.to_string_lossy().to_string()];
let max_depth = 50; let max_depth = 50;
let mut depth = 0; let mut depth = 0;
while let Some(parent) = current.parent() { while let Some(parent) = current.parent() {
depth += 1; depth += 1;
if depth > max_depth { if depth > max_depth {
break; break;
} }
attempted_paths.push(parent.to_string_lossy().to_string()); attempted_paths.push(parent.to_string_lossy().to_string());
if let Ok(repo) = GitRepo::open(parent) { if let Ok(repo) = GitRepo::open(parent) {
return Ok(repo); return Ok(repo);
} }
current = parent; current = parent;
} }
if let Ok(output) = std::process::Command::new("git") if let Ok(output) = std::process::Command::new("git")
.args(&["rev-parse", "--show-toplevel"]) .args(["rev-parse", "--show-toplevel"])
.current_dir(&resolved_start) .current_dir(&resolved_start)
.output() .output()
&& output.status.success()
{ {
if output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout);
let stdout = String::from_utf8_lossy(&output.stdout); let git_root = stdout.trim();
let git_root = stdout.trim(); if !git_root.is_empty()
if !git_root.is_empty() { && let Ok(repo) = GitRepo::open(git_root)
if let Ok(repo) = GitRepo::open(git_root) { {
return Ok(repo); return Ok(repo);
}
}
} }
} }
let diagnosis = diagnose_repo_issue(&resolved_start); let diagnosis = diagnose_repo_issue(&resolved_start);
bail!( bail!(
"No git repository found.\n\ "No git repository found.\n\
\n\ \n\
@@ -1037,6 +1177,105 @@ impl<'a> GitConfigHelper<'a> {
} }
} }
/// Configuration source indicator
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigSource {
Local,
Global,
NotSet,
}
impl std::fmt::Display for ConfigSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigSource::Local => write!(f, "local"),
ConfigSource::Global => write!(f, "global"),
ConfigSource::NotSet => write!(f, "not set"),
}
}
}
/// Single configuration entry with source information
#[derive(Debug, Clone)]
pub struct ConfigEntry {
pub value: Option<String>,
pub source: ConfigSource,
pub local_value: Option<String>,
pub global_value: Option<String>,
}
impl ConfigEntry {
pub fn new(local: Option<String>, global: Option<String>) -> Self {
let (value, source) = match (&local, &global) {
(Some(_), _) => (local.clone(), ConfigSource::Local),
(None, Some(_)) => (global.clone(), ConfigSource::Global),
(None, None) => (None, ConfigSource::NotSet),
};
Self {
value,
source,
local_value: local,
global_value: global,
}
}
pub fn is_set(&self) -> bool {
self.value.is_some()
}
pub fn is_local(&self) -> bool {
self.source == ConfigSource::Local
}
pub fn is_global(&self) -> bool {
self.source == ConfigSource::Global
}
}
/// Merged user configuration with local/global source tracking
#[derive(Debug, Clone)]
pub struct MergedUserConfig {
pub name: ConfigEntry,
pub email: ConfigEntry,
pub signing_key: ConfigEntry,
pub ssh_command: ConfigEntry,
pub commit_gpgsign: ConfigEntry,
pub tag_gpgsign: ConfigEntry,
}
impl MergedUserConfig {
pub fn from_repo(repo: &Repository) -> Result<Self> {
let local_config = repo.config().ok();
let global_config = git2::Config::open_default().ok();
let get_entry = |key: &str| -> ConfigEntry {
let local = local_config.as_ref().and_then(|c| c.get_string(key).ok());
let global = global_config.as_ref().and_then(|c| c.get_string(key).ok());
ConfigEntry::new(local, global)
};
Ok(Self {
name: get_entry("user.name"),
email: get_entry("user.email"),
signing_key: get_entry("user.signingkey"),
ssh_command: get_entry("core.sshCommand"),
commit_gpgsign: get_entry("commit.gpgsign"),
tag_gpgsign: get_entry("tag.gpgsign"),
})
}
pub fn is_complete(&self) -> bool {
self.name.is_set() && self.email.is_set()
}
pub fn has_local_overrides(&self) -> bool {
self.name.is_local()
|| self.email.is_local()
|| self.signing_key.is_local()
|| self.ssh_command.is_local()
}
}
/// User configuration for git /// User configuration for git
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct UserConfig { pub struct UserConfig {
@@ -1060,23 +1299,38 @@ impl UserConfig {
diffs.push(ConfigDiff { diffs.push(ConfigDiff {
key: "user.name".to_string(), key: "user.name".to_string(),
left: self.name.clone().unwrap_or_else(|| "<not set>".to_string()), left: self.name.clone().unwrap_or_else(|| "<not set>".to_string()),
right: other.name.clone().unwrap_or_else(|| "<not set>".to_string()), right: other
.name
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
}); });
} }
if self.email != other.email { if self.email != other.email {
diffs.push(ConfigDiff { diffs.push(ConfigDiff {
key: "user.email".to_string(), key: "user.email".to_string(),
left: self.email.clone().unwrap_or_else(|| "<not set>".to_string()), left: self
right: other.email.clone().unwrap_or_else(|| "<not set>".to_string()), .email
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
right: other
.email
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
}); });
} }
if self.signing_key != other.signing_key { if self.signing_key != other.signing_key {
diffs.push(ConfigDiff { diffs.push(ConfigDiff {
key: "user.signingkey".to_string(), key: "user.signingkey".to_string(),
left: self.signing_key.clone().unwrap_or_else(|| "<not set>".to_string()), left: self
right: other.signing_key.clone().unwrap_or_else(|| "<not set>".to_string()), .signing_key
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
right: other
.signing_key
.clone()
.unwrap_or_else(|| "<not set>".to_string()),
}); });
} }

View File

@@ -1,5 +1,5 @@
use super::GitRepo; use super::GitRepo;
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use semver::Version; use semver::Version;
/// Tag builder for creating tags /// Tag builder for creating tags
@@ -69,19 +69,19 @@ impl TagBuilder {
/// Build tag message /// Build tag message
pub fn build_message(&self) -> Result<String> { pub fn build_message(&self) -> Result<String> {
let message = self.message.as_ref() let message = self.message.as_ref().cloned().unwrap_or_else(|| {
.cloned() let name = self.name.as_deref().unwrap_or("unknown");
.unwrap_or_else(|| { format!("Release {}", name)
let name = self.name.as_deref().unwrap_or("unknown"); });
format!("Release {}", name)
});
Ok(message) Ok(message)
} }
/// Execute tag creation /// Execute tag creation
pub fn execute(&self, repo: &GitRepo) -> Result<()> { pub fn execute(&self, repo: &GitRepo) -> Result<()> {
let name = self.name.as_ref() let name = self
.name
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Tag name is required"))?; .ok_or_else(|| anyhow::anyhow!("Tag name is required"))?;
if !self.force { if !self.force {
@@ -105,10 +105,10 @@ impl TagBuilder {
/// Execute and push tag /// Execute and push tag
pub fn execute_and_push(&self, repo: &GitRepo, remote: &str) -> Result<()> { pub fn execute_and_push(&self, repo: &GitRepo, remote: &str) -> Result<()> {
self.execute(repo)?; self.execute(repo)?;
let name = self.name.as_ref().unwrap(); let name = self.name.as_ref().unwrap();
repo.push(remote, &format!("refs/tags/{}", name))?; repo.push(remote, &format!("refs/tags/{}", name))?;
Ok(()) Ok(())
} }
} }
@@ -136,7 +136,10 @@ impl VersionBump {
"minor" => Ok(Self::Minor), "minor" => Ok(Self::Minor),
"patch" => Ok(Self::Patch), "patch" => Ok(Self::Patch),
"prerelease" | "pre" => Ok(Self::Prerelease), "prerelease" | "pre" => Ok(Self::Prerelease),
_ => bail!("Invalid version bump: {}. Use: major, minor, patch, prerelease", s), _ => bail!(
"Invalid version bump: {}. Use: major, minor, patch, prerelease",
s
),
} }
} }
@@ -149,7 +152,7 @@ impl VersionBump {
/// Get latest version tag from repository /// Get latest version tag from repository
pub fn get_latest_version(repo: &GitRepo, prefix: &str) -> Result<Option<Version>> { pub fn get_latest_version(repo: &GitRepo, prefix: &str) -> Result<Option<Version>> {
let tags = repo.get_tags()?; let tags = repo.get_tags()?;
let mut versions: Vec<Version> = tags let mut versions: Vec<Version> = tags
.iter() .iter()
.filter_map(|t| { .filter_map(|t| {
@@ -158,9 +161,9 @@ pub fn get_latest_version(repo: &GitRepo, prefix: &str) -> Result<Option<Version
Version::parse(version_str).ok() Version::parse(version_str).ok()
}) })
.collect(); .collect();
versions.sort_by(|a, b| b.cmp(a)); // Descending order versions.sort_by(|a, b| b.cmp(a)); // Descending order
Ok(versions.into_iter().next()) Ok(versions.into_iter().next())
} }
@@ -183,14 +186,17 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
let mut has_breaking = false; let mut has_breaking = false;
let mut has_feature = false; let mut has_feature = false;
let mut has_fix = false; let mut has_fix = false;
for commit in commits { for commit in commits {
let msg = commit.message.to_lowercase(); let msg = commit.message.to_lowercase();
if msg.contains("breaking change") || msg.contains("breaking-change") || msg.contains("breaking_change") { if msg.contains("breaking change")
|| msg.contains("breaking-change")
|| msg.contains("breaking_change")
{
has_breaking = true; has_breaking = true;
} }
if let Some(commit_type) = commit.commit_type() { if let Some(commit_type) = commit.commit_type() {
match commit_type.as_str() { match commit_type.as_str() {
"feat" => has_feature = true, "feat" => has_feature = true,
@@ -199,7 +205,7 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
} }
} }
} }
if has_breaking { if has_breaking {
VersionBump::Major VersionBump::Major
} else if has_feature { } else if has_feature {
@@ -214,20 +220,20 @@ pub fn suggest_version_bump(commits: &[super::CommitInfo]) -> VersionBump {
/// Generate tag message from commits /// Generate tag message from commits
pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> String { pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> String {
let mut message = format!("Release {}\n\n", version); let mut message = format!("Release {}\n\n", version);
// Group commits by type // Group commits by type
let mut features = vec![]; let mut features = vec![];
let mut fixes = vec![]; let mut fixes = vec![];
let mut other = vec![]; let mut other = vec![];
let mut breaking = vec![]; let mut breaking = vec![];
for commit in commits { for commit in commits {
let subject = commit.subject(); let subject = commit.subject();
if commit.message.contains("BREAKING CHANGE") { if commit.message.contains("BREAKING CHANGE") {
breaking.push(subject.to_string()); breaking.push(subject.to_string());
} }
if let Some(commit_type) = commit.commit_type() { if let Some(commit_type) = commit.commit_type() {
match commit_type.as_str() { match commit_type.as_str() {
"feat" => features.push(subject.to_string()), "feat" => features.push(subject.to_string()),
@@ -238,7 +244,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
other.push(subject.to_string()); other.push(subject.to_string());
} }
} }
// Build message // Build message
if !breaking.is_empty() { if !breaking.is_empty() {
message.push_str("## Breaking Changes\n\n"); message.push_str("## Breaking Changes\n\n");
@@ -247,7 +253,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
} }
message.push('\n'); message.push('\n');
} }
if !features.is_empty() { if !features.is_empty() {
message.push_str("## Features\n\n"); message.push_str("## Features\n\n");
for item in &features { for item in &features {
@@ -255,7 +261,7 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
} }
message.push('\n'); message.push('\n');
} }
if !fixes.is_empty() { if !fixes.is_empty() {
message.push_str("## Bug Fixes\n\n"); message.push_str("## Bug Fixes\n\n");
for item in &fixes { for item in &fixes {
@@ -263,36 +269,36 @@ pub fn generate_tag_message(version: &str, commits: &[super::CommitInfo]) -> Str
} }
message.push('\n'); message.push('\n');
} }
if !other.is_empty() { if !other.is_empty() {
message.push_str("## Other Changes\n\n"); message.push_str("## Other Changes\n\n");
for item in &other { for item in &other {
message.push_str(&format!("- {}\n", item)); message.push_str(&format!("- {}\n", item));
} }
} }
message message
} }
/// Tag deletion helper /// Tag deletion helper
pub fn delete_tag(repo: &GitRepo, name: &str, remote: Option<&str>) -> Result<()> { pub fn delete_tag(repo: &GitRepo, name: &str, remote: Option<&str>) -> Result<()> {
repo.delete_tag(name)?; repo.delete_tag(name)?;
if let Some(remote) = remote { if let Some(remote) = remote {
use std::process::Command; use std::process::Command;
let refspec = format!(":refs/tags/{}", name); let refspec = format!(":refs/tags/{}", name);
let output = Command::new("git") let output = Command::new("git")
.args(&["push", remote, &refspec]) .args(["push", remote, &refspec])
.current_dir(repo.path()) .current_dir(repo.path())
.output()?; .output()?;
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
bail!("Failed to delete remote tag: {}", stderr); bail!("Failed to delete remote tag: {}", stderr);
} }
} }
Ok(()) Ok(())
} }
@@ -303,7 +309,7 @@ pub fn list_tags(
limit: Option<usize>, limit: Option<usize>,
) -> Result<Vec<super::TagInfo>> { ) -> Result<Vec<super::TagInfo>> {
let tags = repo.get_tags()?; let tags = repo.get_tags()?;
let filtered: Vec<_> = tags let filtered: Vec<_> = tags
.into_iter() .into_iter()
.filter(|t| { .filter(|t| {
@@ -314,7 +320,7 @@ pub fn list_tags(
} }
}) })
.collect(); .collect();
if let Some(limit) = limit { if let Some(limit) = limit {
Ok(filtered.into_iter().take(limit).collect()) Ok(filtered.into_iter().take(limit).collect())
} else { } else {

View File

@@ -267,7 +267,9 @@ impl Messages {
Language::Chinese => "没有可提交的更改。工作树是干净的。", Language::Chinese => "没有可提交的更改。工作树是干净的。",
Language::Japanese => "コミットする変更がありません。作業ツリーはクリーンです。", Language::Japanese => "コミットする変更がありません。作業ツリーはクリーンです。",
Language::Korean => "커밋할 변경 사항이 없습니다. 작업 트리가 깨끗합니다.", Language::Korean => "커밋할 변경 사항이 없습니다. 작업 트리가 깨끗합니다.",
Language::Spanish => "No hay cambios para hacer commit. El árbol de trabajo está limpio.", Language::Spanish => {
"No hay cambios para hacer commit. El árbol de trabajo está limpio."
}
Language::French => "Aucun changement à commiter. L'arbre de travail est propre.", Language::French => "Aucun changement à commiter. L'arbre de travail est propre.",
Language::German => "Keine Änderungen zum Committen. Arbeitsbaum ist sauber.", Language::German => "Keine Änderungen zum Committen. Arbeitsbaum ist sauber.",
} }
@@ -289,11 +291,19 @@ impl Messages {
match self.language { match self.language {
Language::English => "No files staged. Auto-staging all changes...", Language::English => "No files staged. Auto-staging all changes...",
Language::Chinese => "没有暂存文件。自动暂存所有更改...", Language::Chinese => "没有暂存文件。自动暂存所有更改...",
Language::Japanese => "ステージされたファイルがありません。すべての変更を自動ステージ中...", Language::Japanese => {
"ステージされたファイルがありません。すべての変更を自動ステージ中..."
}
Language::Korean => "스테이징된 파일이 없습니다. 모든 변경 사항을 자동 스테이징 중...", Language::Korean => "스테이징된 파일이 없습니다. 모든 변경 사항을 자동 스테이징 중...",
Language::Spanish => "No hay archivos preparados. Preparando automáticamente todos los cambios...", Language::Spanish => {
Language::French => "Aucun fichier indexé. Indexation automatique de tous les changements...", "No hay archivos preparados. Preparando automáticamente todos los cambios..."
Language::German => "Keine Dateien bereitgestellt. Alle Änderungen werden automatisch bereitgestellt...", }
Language::French => {
"Aucun fichier indexé. Indexation automatique de tous les changements..."
}
Language::German => {
"Keine Dateien bereitgestellt. Alle Änderungen werden automatisch bereitgestellt..."
}
} }
} }
@@ -359,12 +369,23 @@ impl Messages {
pub fn ai_generating_tag(&self, count: usize) -> String { pub fn ai_generating_tag(&self, count: usize) -> String {
match self.language { match self.language {
Language::English => format!("🤖 AI is generating tag message from {} commits...", count), Language::English => {
format!("🤖 AI is generating tag message from {} commits...", count)
}
Language::Chinese => format!("🤖 AI 正在从 {} 个提交生成标签消息...", count), Language::Chinese => format!("🤖 AI 正在从 {} 个提交生成标签消息...", count),
Language::Japanese => format!("🤖 AIが{}個のコミットからタグメッセージを生成しています...", count), Language::Japanese => format!(
"🤖 AIが{}個のコミットからタグメッセージを生成しています...",
count
),
Language::Korean => format!("🤖 AI가 {}개의 커밋에서 태그 메시지를 생성 중...", count), Language::Korean => format!("🤖 AI가 {}개의 커밋에서 태그 메시지를 생성 중...", count),
Language::Spanish => format!("🤖 La IA está generando mensaje de etiqueta desde {} commits...", count), Language::Spanish => format!(
Language::French => format!("🤖 L'IA génère le message dtiquette à partir de {} commits...", count), "🤖 La IA está generando mensaje de etiqueta desde {} commits...",
count
),
Language::French => format!(
"🤖 L'IA génère le message d'étiquette à partir de {} commits...",
count
),
Language::German => format!("🤖 KI generiert Tag-Nachricht aus {} Commits...", count), Language::German => format!("🤖 KI generiert Tag-Nachricht aus {} Commits...", count),
} }
} }

View File

@@ -7,7 +7,11 @@ pub struct Translator {
} }
impl Translator { impl Translator {
pub fn new(language: Language, keep_types_english: bool, keep_changelog_types_english: bool) -> Self { pub fn new(
language: Language,
keep_types_english: bool,
keep_changelog_types_english: bool,
) -> Self {
Self { Self {
language, language,
keep_types_english, keep_types_english,
@@ -227,7 +231,11 @@ pub fn translate_commit_type(commit_type: &str, language: Language, keep_english
translator.translate_commit_type(commit_type) translator.translate_commit_type(commit_type)
} }
pub fn translate_changelog_category(category: &str, language: Language, keep_english: bool) -> String { pub fn translate_changelog_category(
category: &str,
language: Language,
keep_english: bool,
) -> String {
let translator = Translator::new(language, true, keep_english); let translator = Translator::new(language, true, keep_english);
translator.translate_changelog_category(category) translator.translate_changelog_category(category)
} }

View File

@@ -1,7 +1,9 @@
use super::{create_http_client, LlmProvider}; use super::thinking::ThinkingStateManager;
use anyhow::{bail, Context, Result}; use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
/// Anthropic Claude API client /// Anthropic Claude API client
@@ -9,6 +11,12 @@ pub struct AnthropicClient {
api_key: String, api_key: String,
model: String, model: String,
client: reqwest::Client, client: reqwest::Client,
thinking_enabled: bool,
thinking_budget_tokens: u32,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -17,24 +25,59 @@ struct MessagesRequest {
max_tokens: u32, max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>, temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
messages: Vec<AnthropicMessage>, messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>, system: Option<Vec<SystemContent>>,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
stream: bool,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Clone)]
struct SystemContent {
#[serde(rename = "type")]
content_type: String,
text: String,
}
#[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
thinking_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
budget_tokens: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AnthropicMessage { struct AnthropicMessage {
role: String, role: String,
content: String, content: AnthropicContent,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
enum AnthropicContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct ContentBlock {
#[serde(rename = "type")]
content_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct MessagesResponse { struct MessagesResponse {
content: Vec<ContentBlock>, content: Vec<ResponseContentBlock>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ContentBlock { struct ResponseContentBlock {
#[serde(rename = "type")] #[serde(rename = "type")]
content_type: String, content_type: String,
text: String, text: String,
@@ -52,31 +95,112 @@ struct AnthropicError {
message: String, message: String,
} }
// --- Streaming SSE event structures ---
#[derive(Debug, Deserialize)]
struct SseEvent {
#[serde(rename = "type")]
event_type: String,
#[serde(default)]
message: Option<SseMessage>,
#[serde(default)]
index: Option<u32>,
#[serde(default)]
content_block: Option<SseContentBlock>,
#[serde(default)]
delta: Option<SseDelta>,
#[serde(default)]
usage: Option<SseUsage>,
}
#[derive(Debug, Deserialize)]
struct SseMessage {
#[serde(default)]
content: Option<Vec<SseContentBlock>>,
}
#[derive(Debug, Deserialize)]
struct SseContentBlock {
#[serde(rename = "type")]
content_type: String,
#[serde(default)]
thinking: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct SseDelta {
#[serde(rename = "type")]
delta_type: Option<String>,
#[serde(default)]
thinking: Option<String>,
#[serde(default)]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct SseUsage {
#[serde(default)]
output_tokens: Option<u32>,
}
impl AnthropicClient { impl AnthropicClient {
/// Create new Anthropic client
pub fn new(api_key: &str, model: &str) -> Result<Self> { pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; let client = create_http_client(Duration::from_secs(60))?;
Ok(Self { Ok(Self {
api_key: api_key.to_string(), api_key: api_key.to_string(),
model: model.to_string(), model: model.to_string(),
client, client,
thinking_enabled: false,
thinking_budget_tokens: 1024,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
}) })
} }
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> { pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?; self.client = create_http_client(timeout)?;
Ok(self) Ok(self)
} }
/// List available models pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_thinking_budget_tokens(mut self, budget_tokens: u32) -> Self {
self.thinking_budget_tokens = budget_tokens;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> { pub async fn list_models(&self) -> Result<Vec<String>> {
// Anthropic doesn't have a models API endpoint, return predefined list
Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect()) Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect())
} }
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> { pub async fn validate_key(&self) -> Result<bool> {
let url = "https://api.anthropic.com/v1/messages"; let url = "https://api.anthropic.com/v1/messages";
@@ -84,14 +208,18 @@ impl AnthropicClient {
model: self.model.clone(), model: self.model.clone(),
max_tokens: 5, max_tokens: 5,
temperature: Some(0.0), temperature: Some(0.0),
top_p: None,
messages: vec![AnthropicMessage { messages: vec![AnthropicMessage {
role: "user".to_string(), role: "user".to_string(),
content: "Hi".to_string(), content: AnthropicContent::Text("Hi".to_string()),
}], }],
system: None, system: None,
thinking: None,
stream: false,
}; };
let response = self.client let response = self
.client
.post(url) .post(url)
.header("x-api-key", &self.api_key) .header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01") .header("anthropic-version", "2023-06-01")
@@ -124,25 +252,28 @@ impl LlmProvider for AnthropicClient {
async fn generate(&self, prompt: &str) -> Result<String> { async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![AnthropicMessage { let messages = vec![AnthropicMessage {
role: "user".to_string(), role: "user".to_string(),
content: prompt.to_string(), content: AnthropicContent::Text(prompt.to_string()),
}]; }];
self.messages_request(messages, None).await self.messages_request_with_retry(messages, None).await
} }
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let messages = vec![AnthropicMessage { let messages = vec![AnthropicMessage {
role: "user".to_string(), role: "user".to_string(),
content: user.to_string(), content: AnthropicContent::Text(user.to_string()),
}]; }];
let system = if system.is_empty() { let system = if system.is_empty() {
None None
} else { } else {
Some(system.to_string()) Some(vec![SystemContent {
content_type: "text".to_string(),
text: system.to_string(),
}])
}; };
self.messages_request(messages, system).await self.messages_request_with_retry(messages, system).await
} }
async fn is_available(&self) -> bool { async fn is_available(&self) -> bool {
@@ -155,22 +286,84 @@ impl LlmProvider for AnthropicClient {
} }
impl AnthropicClient { impl AnthropicClient {
async fn messages_request_with_retry(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self
.messages_request(messages.clone(), system.clone())
.await
{
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn messages_request( async fn messages_request(
&self, &self,
messages: Vec<AnthropicMessage>, messages: Vec<AnthropicMessage>,
system: Option<String>, system: Option<Vec<SystemContent>>,
) -> Result<String> {
if self.thinking_enabled {
self.streaming_messages_request(messages, system).await
} else {
self.non_streaming_messages_request(messages, system).await
}
}
async fn non_streaming_messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> { ) -> Result<String> {
let url = "https://api.anthropic.com/v1/messages"; let url = "https://api.anthropic.com/v1/messages";
let temperature = if self.temperature == 0.0 {
None
} else {
Some(self.temperature)
};
let request = MessagesRequest { let request = MessagesRequest {
model: self.model.clone(), model: self.model.clone(),
max_tokens: 500, max_tokens: self.max_tokens,
temperature: Some(0.7), temperature,
top_p: self.top_p,
messages, messages,
system, system,
thinking: Some(ThinkingConfig {
thinking_type: "disabled".to_string(),
budget_tokens: None,
}),
stream: false,
}; };
let response = self.client let response = self
.client
.post(url) .post(url)
.header("x-api-key", &self.api_key) .header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01") .header("anthropic-version", "2023-06-01")
@@ -179,35 +372,205 @@ impl AnthropicClient {
.send() .send()
.await .await
.context("Failed to send request to Anthropic")?; .context("Failed to send request to Anthropic")?;
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) { if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type); bail!(
"Anthropic API error: {} ({})",
error.error.message,
error.error.error_type
);
} }
bail!("Anthropic API error: {} - {}", status, text); bail!("Anthropic API error: {} - {}", status, text);
} }
let result: MessagesResponse = response let result: MessagesResponse = response
.json() .json()
.await .await
.context("Failed to parse Anthropic response")?; .context("Failed to parse Anthropic response")?;
result.content result
.content
.into_iter() .into_iter()
.find(|c| c.content_type == "text") .find(|c| c.content_type == "text")
.map(|c| c.text.trim().to_string()) .map(|c| c.text.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic")) .ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
} }
/// Streaming request for thinking mode, filters thinking content blocks
async fn streaming_messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<Vec<SystemContent>>,
) -> Result<String> {
let url = "https://api.anthropic.com/v1/messages";
let thinking = ThinkingConfig {
thinking_type: "enabled".to_string(),
budget_tokens: Some(self.thinking_budget_tokens),
};
// max_tokens must exceed budget_tokens
let max_tokens = (self.max_tokens).max(self.thinking_budget_tokens + 100);
let request = MessagesRequest {
model: self.model.clone(),
max_tokens,
temperature: None, // must be omitted for thinking mode
top_p: None,
messages,
system,
thinking: Some(thinking),
stream: true,
};
let response = self
.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(&request)
.send()
.await
.context("Failed to send streaming request to Anthropic")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Anthropic API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Anthropic API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut in_thinking = false;
let mut has_reasoning = false;
let mut has_content = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
// Parse SSE event line
if let Some(data) = line.strip_prefix("data: ") {
if let Ok(event) = serde_json::from_str::<SseEvent>(data) {
match event.event_type.as_str() {
"content_block_start" => {
if let Some(ref block) = event.content_block {
if block.content_type == "thinking" {
in_thinking = true;
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
}
}
}
"content_block_delta" => {
if let Some(ref delta) = event.delta {
// Thinking delta - ignore content but track state
if delta.thinking.is_some() {
continue;
}
// Text delta - collect
if in_thinking && delta.text.is_some() {
// Transition from thinking to text
if let Some(state) = thinking_state {
state.end_thinking();
}
in_thinking = false;
}
if let Some(ref text) = delta.text
&& !text.is_empty()
{
has_content = true;
content_buffer.push_str(text);
}
}
}
"content_block_stop" => {
if in_thinking {
if let Some(state) = thinking_state {
state.end_thinking();
}
in_thinking = false;
}
}
_ => {}
}
}
}
}
}
// Ensure thinking state is ended
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"Anthropic returned thinking content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from Anthropic. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
} }
/// Available Anthropic models /// Available Anthropic models (Claude 4 series with extended thinking)
pub const ANTHROPIC_MODELS: &[&str] = &[ pub const ANTHROPIC_MODELS: &[&str] = &[
"claude-opus-4-7",
"claude-sonnet-4-6",
"claude-haiku-4-5",
// Legacy models
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
@@ -216,7 +579,6 @@ pub const ANTHROPIC_MODELS: &[&str] = &[
"claude-instant-1.2", "claude-instant-1.2",
]; ];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool { pub fn is_valid_model(model: &str) -> bool {
ANTHROPIC_MODELS.contains(&model) ANTHROPIC_MODELS.contains(&model)
} }
@@ -226,8 +588,68 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_model_validation() { fn test_model_validation_claude4() {
assert!(is_valid_model("claude-opus-4-7"));
assert!(is_valid_model("claude-sonnet-4-6"));
assert!(is_valid_model("claude-haiku-4-5"));
assert!(is_valid_model("claude-3-sonnet-20240229")); assert!(is_valid_model("claude-3-sonnet-20240229"));
assert!(!is_valid_model("invalid-model")); assert!(!is_valid_model("invalid-model"));
} }
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
budget_tokens: Some(2048),
};
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains(r#""type":"enabled""#));
assert!(json.contains(r#""budget_tokens":2048"#));
}
#[test]
fn test_thinking_config_disabled_serialization() {
let config = ThinkingConfig {
thinking_type: "disabled".to_string(),
budget_tokens: None,
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"disabled"}"#);
}
#[test]
fn test_system_content_serialization() {
let content = SystemContent {
content_type: "text".to_string(),
text: "You are helpful.".to_string(),
};
let json = serde_json::to_string(&content).unwrap();
assert!(json.contains(r#""type":"text""#));
}
#[test]
fn test_sse_event_parsing_content_block_start() {
let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#;
let event: SseEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.event_type, "content_block_start");
assert_eq!(event.content_block.unwrap().content_type, "thinking");
}
#[test]
fn test_sse_event_parsing_text_delta() {
let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
let event: SseEvent = serde_json::from_str(json).unwrap();
assert_eq!(event.event_type, "content_block_delta");
assert_eq!(event.delta.unwrap().text, Some("Hello".to_string()));
}
#[test]
fn test_anthropic_content_text() {
let msg = AnthropicMessage {
role: "user".to_string(),
content: AnthropicContent::Text("Hello".to_string()),
};
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains(r#""content":"Hello""#));
}
} }

View File

@@ -1,7 +1,9 @@
use super::{create_http_client, LlmProvider}; use super::thinking::ThinkingStateManager;
use anyhow::{bail, Context, Result}; use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
/// DeepSeek API client /// DeepSeek API client
@@ -10,6 +12,11 @@ pub struct DeepSeekClient {
api_key: String, api_key: String,
model: String, model: String,
client: reqwest::Client, client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
thinking_state: Option<Arc<ThinkingStateManager>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -20,13 +27,31 @@ struct ChatCompletionRequest {
max_tokens: Option<u32>, max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>, temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
stream: bool, stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
thinking_type: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message { struct Message {
role: String, role: String,
content: String, content: String,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -37,6 +62,31 @@ struct ChatCompletionResponse {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Choice { struct Choice {
message: Message, message: Message,
#[serde(default)]
reasoning_content: Option<String>,
}
// --- Streaming response structures ---
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
#[serde(default)]
finish_reason: Option<String>,
index: Option<u32>,
}
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -52,41 +102,73 @@ struct ApiError {
} }
impl DeepSeekClient { impl DeepSeekClient {
/// Create new DeepSeek client
pub fn new(api_key: &str, model: &str) -> Result<Self> { pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; let client = create_http_client(Duration::from_secs(300))?;
Ok(Self { Ok(Self {
base_url: "https://api.deepseek.com/v1".to_string(), base_url: "https://api.deepseek.com".to_string(),
api_key: api_key.to_string(), api_key: api_key.to_string(),
model: model.to_string(), model: model.to_string(),
client, client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
thinking_state: None,
}) })
} }
/// Create with custom base URL
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> { pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; let client = create_http_client(Duration::from_secs(300))?;
Ok(Self { Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(), api_key: api_key.to_string(),
model: model.to_string(), model: model.to_string(),
client, client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
thinking_state: None,
}) })
} }
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> { pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?; self.client = create_http_client(timeout)?;
Ok(self) Ok(self)
} }
/// List available models pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
self.reasoning_effort = effort;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> { pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url); let url = format!("{}/models", self.base_url);
let response = self.client let response = self
.client
.get(&url) .get(&url)
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.send() .send()
@@ -101,11 +183,11 @@ impl DeepSeekClient {
#[derive(Deserialize)] #[derive(Deserialize)]
struct ModelsResponse { struct ModelsResponse {
data: Vec<Model>, data: Vec<ModelId>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct Model { struct ModelId {
id: String, id: String,
} }
@@ -117,7 +199,6 @@ impl DeepSeekClient {
Ok(result.data.into_iter().map(|m| m.id).collect()) Ok(result.data.into_iter().map(|m| m.id).collect())
} }
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> { pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await { match self.list_models().await {
Ok(_) => Ok(true), Ok(_) => Ok(true),
@@ -136,32 +217,33 @@ impl DeepSeekClient {
#[async_trait] #[async_trait]
impl LlmProvider for DeepSeekClient { impl LlmProvider for DeepSeekClient {
async fn generate(&self, prompt: &str) -> Result<String> { async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![ let messages = vec![Message {
Message { role: "user".to_string(),
role: "user".to_string(), content: prompt.to_string(),
content: prompt.to_string(), reasoning_content: None,
}, }];
];
self.chat_completion_with_retry(messages).await
self.chat_completion(messages).await
} }
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![]; let mut messages = vec![];
if !system.is_empty() { if !system.is_empty() {
messages.push(Message { messages.push(Message {
role: "system".to_string(), role: "system".to_string(),
content: system.to_string(), content: system.to_string(),
reasoning_content: None,
}); });
} }
messages.push(Message { messages.push(Message {
role: "user".to_string(), role: "user".to_string(),
content: user.to_string(), content: user.to_string(),
reasoning_content: None,
}); });
self.chat_completion(messages).await self.chat_completion_with_retry(messages).await
} }
async fn is_available(&self) -> bool { async fn is_available(&self) -> bool {
@@ -174,59 +256,291 @@ impl LlmProvider for DeepSeekClient {
} }
impl DeepSeekClient { impl DeepSeekClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
// 网络临时错误才重试
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
// 指数退避
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> { async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let thinking = Some(ThinkingConfig {
thinking_type: if self.thinking_enabled {
"enabled".to_string()
} else {
"disabled".to_string()
},
});
// 思考模式下temperature/top_p 等参数不应传递
// 非思考模式下可以正常传递
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
if self.thinking_enabled {
(None, Some(self.max_tokens), None, None, None)
} else {
(
Some(self.temperature),
Some(self.max_tokens),
None,
None,
None,
)
};
let reasoning_effort = if self.thinking_enabled {
self.reasoning_effort.clone()
} else {
None
};
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.model.clone(), model: self.model.clone(),
messages, messages: messages.clone(),
max_tokens: Some(500), max_tokens,
temperature: Some(0.7), temperature,
stream: false, top_p,
presence_penalty,
frequency_penalty,
stream: self.thinking_enabled,
thinking,
reasoning_effort,
}; };
let response = self.client if self.thinking_enabled {
.post(&url) self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.json(&request) .json(request)
.send() .send()
.await .await
.context("Failed to send request to DeepSeek")?; .context("Failed to send request to DeepSeek")?;
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) { if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type); bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
} }
bail!("DeepSeek API error: {} - {}", status, text); bail!("DeepSeek API error: {} - {}", status, text);
} }
let result: ChatCompletionResponse = response let result: ChatCompletionResponse = response
.json() .json()
.await .await
.context("Failed to parse DeepSeek response")?; .context("Failed to parse DeepSeek response")?;
result.choices result
.choices
.into_iter() .into_iter()
.next() .next()
.map(|c| c.message.content.trim().to_string()) .map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek")) .ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
} }
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to DeepSeek")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("DeepSeek API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
// 处理完整行
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
// SSE 格式data: {...} 或 data: [DONE]
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
// 处理 reasoning_content
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
// reasoning_content 不对外输出,仅用于内部状态判断
continue;
}
// 处理 content
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
// reasoning 结束content 开始出现时移除 thinking 标识
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
// 检查 finish_reason
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行(可能是心跳或注释)
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"DeepSeek returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from DeepSeek. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
} }
/// Available DeepSeek models /// 可用 DeepSeek 模型列表
/// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列
pub const DEEPSEEK_MODELS: &[&str] = &[ pub const DEEPSEEK_MODELS: &[&str] = &[
"deepseek-v4-flash",
"deepseek-v4-pro",
// 兼容旧版模型 ID将于 2026-07-24 停用)
"deepseek-chat", "deepseek-chat",
"deepseek-coder", "deepseek-reasoner",
]; ];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool { pub fn is_valid_model(model: &str) -> bool {
DEEPSEEK_MODELS.contains(&model) DEEPSEEK_MODELS.contains(&model)
} }
@@ -236,8 +550,73 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_model_validation() { fn test_model_validation_v4() {
assert!(is_valid_model("deepseek-v4-flash"));
assert!(is_valid_model("deepseek-v4-pro"));
assert!(is_valid_model("deepseek-chat")); assert!(is_valid_model("deepseek-chat"));
assert!(is_valid_model("deepseek-reasoner"));
assert!(!is_valid_model("invalid-model")); assert!(!is_valid_model("invalid-model"));
assert!(!is_valid_model("deepseek-v3"));
} }
}
#[test]
fn test_client_builder_defaults() {
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();
assert!(!client.thinking_enabled);
assert_eq!(client.max_tokens, 500);
assert_eq!(client.temperature, 0.7);
assert!(client.reasoning_effort.is_none());
assert!(client.thinking_state.is_none());
}
#[test]
fn test_client_builder_with_thinking() {
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash")
.unwrap()
.with_thinking(true)
.with_reasoning_effort(Some("high".to_string()))
.with_max_tokens(1000)
.with_temperature(0.5);
assert!(client.thinking_enabled);
assert_eq!(client.reasoning_effort, Some("high".to_string()));
assert_eq!(client.max_tokens, 1000);
assert_eq!(client.temperature, 0.5);
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"enabled"}"#);
}
#[test]
fn test_message_serialization_without_reasoning() {
let msg = Message {
role: "user".to_string(),
content: "Hello".to_string(),
reasoning_content: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(!json.contains("reasoning_content"));
}
#[test]
fn test_stream_delta_parsing() {
let json = r#"{"content":"Hello","reasoning_content":null}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert_eq!(delta.content, Some("Hello".to_string()));
assert!(delta.reasoning_content.is_none());
}
#[test]
fn test_stream_delta_reasoning_only() {
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert!(delta.content.is_none());
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
}
}

View File

@@ -1,244 +1,587 @@
use super::{create_http_client, LlmProvider}; use super::thinking::ThinkingStateManager;
use anyhow::{bail, Context, Result}; use super::{LlmProvider, create_http_client};
use async_trait::async_trait; use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize}; use async_trait::async_trait;
use std::time::Duration; use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Kimi API client (Moonshot AI) use std::time::Duration;
pub struct KimiClient {
base_url: String, /// Kimi API client (Moonshot AI)
api_key: String, pub struct KimiClient {
model: String, base_url: String,
client: reqwest::Client, api_key: String,
} model: String,
client: reqwest::Client,
#[derive(Debug, Serialize)] thinking_enabled: bool,
struct ChatCompletionRequest { max_tokens: u32,
model: String, temperature: f32,
messages: Vec<Message>, thinking_state: Option<Arc<ThinkingStateManager>>,
#[serde(skip_serializing_if = "Option::is_none")] }
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[derive(Debug, Serialize)]
temperature: Option<f32>, struct ChatCompletionRequest {
stream: bool, model: String,
} messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
#[derive(Debug, Serialize, Deserialize)] max_tokens: Option<u32>,
struct Message { #[serde(skip_serializing_if = "Option::is_none")]
role: String, temperature: Option<f32>,
content: String, stream: bool,
} #[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
#[derive(Debug, Deserialize)] }
struct ChatCompletionResponse {
choices: Vec<Choice>, #[derive(Debug, Serialize)]
} struct ThinkingConfig {
#[serde(rename = "type")]
#[derive(Debug, Deserialize)] thinking_type: String,
struct Choice { }
message: Message,
} #[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
#[derive(Debug, Deserialize)] role: String,
struct ErrorResponse { content: String,
error: ApiError, #[serde(skip_serializing_if = "Option::is_none")]
} reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ApiError { #[derive(Debug, Deserialize)]
message: String, struct ChatCompletionResponse {
#[serde(rename = "type")] choices: Vec<Choice>,
error_type: String, }
}
#[derive(Debug, Deserialize)]
impl KimiClient { struct Choice {
/// Create new Kimi client message: Message,
pub fn new(api_key: &str, model: &str) -> Result<Self> { #[serde(default)]
let client = create_http_client(Duration::from_secs(60))?; reasoning_content: Option<String>,
}
Ok(Self {
base_url: "https://api.moonshot.cn/v1".to_string(), // --- Streaming response structures ---
api_key: api_key.to_string(),
model: model.to_string(), #[derive(Debug, Deserialize)]
client, struct StreamChunk {
}) choices: Vec<StreamChoice>,
} }
/// Create with custom base URL #[derive(Debug, Deserialize)]
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> { struct StreamChoice {
let client = create_http_client(Duration::from_secs(60))?; delta: StreamDelta,
#[serde(default)]
Ok(Self { finish_reason: Option<String>,
base_url: base_url.trim_end_matches('/').to_string(), index: Option<u32>,
api_key: api_key.to_string(), }
model: model.to_string(),
client, #[derive(Debug, Deserialize, Default)]
}) struct StreamDelta {
} #[serde(default)]
content: Option<String>,
/// Set timeout #[serde(default)]
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> { reasoning_content: Option<String>,
self.client = create_http_client(timeout)?; }
Ok(self)
} #[derive(Debug, Deserialize)]
struct ErrorResponse {
/// List available models error: ApiError,
pub async fn list_models(&self) -> Result<Vec<String>> { }
let url = format!("{}/models", self.base_url);
#[derive(Debug, Deserialize)]
let response = self.client struct ApiError {
.get(&url) message: String,
.header("Authorization", format!("Bearer {}", self.api_key)) #[serde(rename = "type")]
.send() error_type: String,
.await }
.context("Failed to list Kimi models")?;
impl KimiClient {
if !response.status().is_success() { pub fn new(api_key: &str, model: &str) -> Result<Self> {
let status = response.status(); let client = create_http_client(Duration::from_secs(300))?;
let text = response.text().await.unwrap_or_default();
bail!("Kimi API error: {} - {}", status, text); Ok(Self {
} base_url: "https://api.moonshot.cn/v1".to_string(),
api_key: api_key.to_string(),
#[derive(Deserialize)] model: model.to_string(),
struct ModelsResponse { client,
data: Vec<Model>, thinking_enabled: false,
} max_tokens: 500,
temperature: 1.0,
#[derive(Deserialize)] thinking_state: None,
struct Model { })
id: String, }
}
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let result: ModelsResponse = response let client = create_http_client(Duration::from_secs(300))?;
.json()
.await Ok(Self {
.context("Failed to parse Kimi response")?; base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
Ok(result.data.into_iter().map(|m| m.id).collect()) model: model.to_string(),
} client,
thinking_enabled: false,
/// Validate API key max_tokens: 500,
pub async fn validate_key(&self) -> Result<bool> { temperature: 1.0,
match self.list_models().await { thinking_state: None,
Ok(_) => Ok(true), })
Err(e) => { }
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") { pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
Ok(false) self.client = create_http_client(timeout)?;
} else { Ok(self)
Err(e) }
}
} pub fn with_thinking(mut self, enabled: bool) -> Self {
} self.thinking_enabled = enabled;
} self
} }
#[async_trait] pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
impl LlmProvider for KimiClient { self.max_tokens = max_tokens;
async fn generate(&self, prompt: &str) -> Result<String> { self
let messages = vec![ }
Message {
role: "user".to_string(), pub fn with_temperature(mut self, temperature: f32) -> Self {
content: prompt.to_string(), self.temperature = temperature;
}, self
]; }
self.chat_completion(messages).await pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
} self.thinking_state = Some(state);
self
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { }
let mut messages = vec![];
pub async fn list_models(&self) -> Result<Vec<String>> {
if !system.is_empty() { let url = format!("{}/models", self.base_url);
messages.push(Message {
role: "system".to_string(), let response = self
content: system.to_string(), .client
}); .get(&url)
} .header("Authorization", format!("Bearer {}", self.api_key))
.send()
messages.push(Message { .await
role: "user".to_string(), .context("Failed to list Kimi models")?;
content: user.to_string(),
}); if !response.status().is_success() {
let status = response.status();
self.chat_completion(messages).await let text = response.text().await.unwrap_or_default();
} bail!("Kimi API error: {} - {}", status, text);
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false) #[derive(Deserialize)]
} struct ModelsResponse {
data: Vec<ModelId>,
fn name(&self) -> &str { }
"kimi"
} #[derive(Deserialize)]
} struct ModelId {
id: String,
impl KimiClient { }
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url); let result: ModelsResponse = response
.json()
let request = ChatCompletionRequest { .await
model: self.model.clone(), .context("Failed to parse Kimi response")?;
messages,
max_tokens: Some(500), Ok(result.data.into_iter().map(|m| m.id).collect())
temperature: Some(0.7), }
stream: false,
}; pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
let response = self.client Ok(_) => Ok(true),
.post(&url) Err(e) => {
.header("Authorization", format!("Bearer {}", self.api_key)) let err_str = e.to_string();
.header("Content-Type", "application/json") if err_str.contains("401") || err_str.contains("Unauthorized") {
.json(&request) Ok(false)
.send() } else {
.await Err(e)
.context("Failed to send request to Kimi")?; }
}
let status = response.status(); }
}
if !status.is_success() { }
let text = response.text().await.unwrap_or_default();
#[async_trait]
// Try to parse error impl LlmProvider for KimiClient {
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) { async fn generate(&self, prompt: &str) -> Result<String> {
bail!("Kimi API error: {} ({})", error.error.message, error.error.error_type); let messages = vec![Message {
} role: "user".to_string(),
content: prompt.to_string(),
bail!("Kimi API error: {} - {}", status, text); reasoning_content: None,
} }];
let result: ChatCompletionResponse = response self.chat_completion_with_retry(messages).await
.json() }
.await
.context("Failed to parse Kimi response")?; async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
result.choices
.into_iter() if !system.is_empty() {
.next() messages.push(Message {
.map(|c| c.message.content.trim().to_string()) role: "system".to_string(),
.ok_or_else(|| anyhow::anyhow!("No response from Kimi")) content: system.to_string(),
} reasoning_content: None,
} });
}
/// Available Kimi models
pub const KIMI_MODELS: &[&str] = &[ messages.push(Message {
"moonshot-v1-8k", role: "user".to_string(),
"moonshot-v1-32k", content: user.to_string(),
"moonshot-v1-128k", reasoning_content: None,
]; });
/// Check if a model name is valid self.chat_completion_with_retry(messages).await
pub fn is_valid_model(model: &str) -> bool { }
KIMI_MODELS.contains(&model)
} async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
#[cfg(test)] }
mod tests {
use super::*; fn name(&self) -> &str {
"kimi"
#[test] }
fn test_model_validation() { }
assert!(is_valid_model("moonshot-v1-8k"));
assert!(!is_valid_model("invalid-model")); impl KimiClient {
} async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
} let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let thinking = Some(ThinkingConfig {
thinking_type: if self.thinking_enabled {
"enabled".to_string()
} else {
"disabled".to_string()
},
});
// Kimi API temperature 要求:
// - 思考模式: temperature 必须为 1.0
// - 非思考模式: temperature 必须为 0.6
let temperature = if self.thinking_enabled {
Some(1.0)
} else {
Some(0.6)
};
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: messages.clone(),
max_tokens: Some(self.max_tokens),
temperature,
stream: self.thinking_enabled,
thinking,
};
if self.thinking_enabled {
self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(request)
.send()
.await
.context("Failed to send request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
result
.choices
.into_iter()
.next()
.map(|c| {
let content = c.message.content.trim().to_string();
if content.is_empty() {
c.reasoning_content
.or(c.message.reasoning_content)
.map(|r| r.trim().to_string())
.unwrap_or_default()
} else {
content
}
})
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
}
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
continue;
}
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"Kimi returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from Kimi. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// 可用 Kimi 模型列表
pub const KIMI_MODELS: &[&str] = &[
// K2 系列(推荐)
"kimi-k2.6",
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-thinking-turbo",
"kimi-k2-instruct",
"kimi-k2-instruct-0905",
// 兼容旧版模型 ID
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
];
pub fn is_valid_model(model: &str) -> bool {
KIMI_MODELS.contains(&model)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation_k2() {
assert!(is_valid_model("kimi-k2.6"));
assert!(is_valid_model("kimi-k2.5"));
assert!(is_valid_model("kimi-k2-thinking"));
assert!(is_valid_model("kimi-k2-thinking-turbo"));
assert!(is_valid_model("moonshot-v1-8k"));
assert!(is_valid_model("moonshot-v1-32k"));
assert!(is_valid_model("moonshot-v1-128k"));
assert!(!is_valid_model("invalid-model"));
assert!(!is_valid_model("kimi-k1.5"));
}
#[test]
fn test_client_builder_defaults() {
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
assert!(!client.thinking_enabled);
assert_eq!(client.max_tokens, 500);
assert_eq!(client.temperature, 1.0);
assert!(client.thinking_state.is_none());
}
#[test]
fn test_client_builder_with_thinking() {
let client = KimiClient::new("test-key", "kimi-k2.6")
.unwrap()
.with_thinking(true)
.with_max_tokens(1000)
.with_temperature(0.5);
assert!(client.thinking_enabled);
assert_eq!(client.max_tokens, 1000);
assert_eq!(client.temperature, 0.5);
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
};
let json = serde_json::to_string(&config).unwrap();
assert_eq!(json, r#"{"type":"enabled"}"#);
}
#[test]
fn test_client_new_defaults() {
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
assert_eq!(client.name(), "kimi");
assert!(!client.thinking_enabled);
}
#[test]
fn test_message_serialization() {
let msg = Message {
role: "user".to_string(),
content: "Hello".to_string(),
reasoning_content: None,
};
let json = serde_json::to_string(&msg).unwrap();
assert!(!json.contains("reasoning_content"));
}
}

View File

@@ -1,20 +1,21 @@
use anyhow::{bail, Context, Result}; use crate::config::Language;
use anyhow::{Context, Result, bail};
use async_trait::async_trait; use async_trait::async_trait;
use std::time::Duration; use std::time::Duration;
use crate::config::Language;
pub mod anthropic;
pub mod deepseek;
pub mod kimi;
pub mod ollama; pub mod ollama;
pub mod openai; pub mod openai;
pub mod anthropic;
pub mod kimi;
pub mod deepseek;
pub mod openrouter; pub mod openrouter;
pub mod thinking;
pub use anthropic::AnthropicClient;
pub use deepseek::DeepSeekClient;
pub use kimi::KimiClient;
pub use ollama::OllamaClient; pub use ollama::OllamaClient;
pub use openai::OpenAiClient; pub use openai::OpenAiClient;
pub use anthropic::AnthropicClient;
pub use kimi::KimiClient;
pub use deepseek::DeepSeekClient;
pub use openrouter::OpenRouterClient; pub use openrouter::OpenRouterClient;
/// LLM provider trait /// LLM provider trait
@@ -22,13 +23,13 @@ pub use openrouter::OpenRouterClient;
pub trait LlmProvider: Send + Sync { pub trait LlmProvider: Send + Sync {
/// Generate text from prompt /// Generate text from prompt
async fn generate(&self, prompt: &str) -> Result<String>; async fn generate(&self, prompt: &str) -> Result<String>;
/// Generate with system prompt /// Generate with system prompt
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>; async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>;
/// Check if provider is available /// Check if provider is available
async fn is_available(&self) -> bool; async fn is_available(&self) -> bool;
/// Get provider name /// Get provider name
fn name(&self) -> &str; fn name(&self) -> &str;
} }
@@ -44,6 +45,7 @@ pub struct LlmClientConfig {
pub max_tokens: u32, pub max_tokens: u32,
pub temperature: f32, pub temperature: f32,
pub timeout: Duration, pub timeout: Duration,
pub thinking_enabled: bool,
} }
impl Default for LlmClientConfig { impl Default for LlmClientConfig {
@@ -52,53 +54,131 @@ impl Default for LlmClientConfig {
max_tokens: 500, max_tokens: 500,
temperature: 0.7, temperature: 0.7,
timeout: Duration::from_secs(30), timeout: Duration::from_secs(30),
thinking_enabled: false,
} }
} }
} }
impl LlmClient { impl LlmClient {
/// Create LLM client from configuration /// Create LLM client from configuration manager
pub async fn from_config(config: &crate::config::LlmConfig) -> Result<Self> { pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result<Self> {
Self::from_config_with_think(manager, manager.config().llm.thinking_enabled).await
}
/// Create LLM client from configuration with explicit thinking override
pub async fn from_config_with_think(
manager: &crate::config::manager::ConfigManager,
thinking_enabled: bool,
) -> Result<Self> {
let config = manager.config();
let client_config = LlmClientConfig { let client_config = LlmClientConfig {
max_tokens: config.max_tokens, max_tokens: config.llm.max_tokens,
temperature: config.temperature, temperature: config.llm.temperature,
timeout: Duration::from_secs(config.timeout), timeout: Duration::from_secs(config.llm.timeout),
thinking_enabled,
}; };
let provider: Box<dyn LlmProvider> = match config.provider.as_str() { let provider = config.llm.provider.as_str();
"ollama" => { let model = config.llm.model.as_str();
Box::new(OllamaClient::new(&config.ollama.url, &config.ollama.model)) let base_url = manager.llm_base_url();
} let api_key = manager.get_api_key();
let provider: Box<dyn LlmProvider> = match provider {
"ollama" => Box::new(
OllamaClient::new(&base_url, model)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature),
),
"openai" => { "openai" => {
let api_key = config.openai.api_key.as_ref() let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?; .ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
Box::new(OpenAiClient::new( let thinking_state = if thinking_enabled {
&config.openai.base_url, Some(thinking::create_console_thinking_state())
api_key, } else {
&config.openai.model, None
)?) };
let mut client = OpenAiClient::new(&base_url, key, model)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
} }
"anthropic" => { "anthropic" => {
let api_key = config.anthropic.api_key.as_ref() let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?; .ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
Box::new(AnthropicClient::new(api_key, &config.anthropic.model)?) let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let budget = config.llm.thinking_budget_tokens.unwrap_or(1024);
let mut client = AnthropicClient::new(key, model)?
.with_thinking(thinking_enabled)
.with_thinking_budget_tokens(budget)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
} }
"kimi" => { "kimi" => {
let api_key = config.kimi.api_key.as_ref() let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?; .ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
Box::new(KimiClient::with_base_url(api_key, &config.kimi.model, &config.kimi.base_url)?) let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let mut client = KimiClient::with_base_url(key, model, &base_url)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
} }
"deepseek" => { "deepseek" => {
let api_key = config.deepseek.api_key.as_ref() let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?; .ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
Box::new(DeepSeekClient::with_base_url(api_key, &config.deepseek.model, &config.deepseek.base_url)?) let thinking_state = if thinking_enabled {
Some(thinking::create_console_thinking_state())
} else {
None
};
let mut client = DeepSeekClient::with_base_url(key, model, &base_url)?
.with_thinking(thinking_enabled)
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?;
if let Some(state) = thinking_state {
client = client.with_thinking_state(state);
}
Box::new(client)
} }
"openrouter" => { "openrouter" => {
let api_key = config.openrouter.api_key.as_ref() let key = api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?; .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
Box::new(OpenRouterClient::with_base_url(api_key, &config.openrouter.model, &config.openrouter.base_url)?) Box::new(
OpenRouterClient::with_base_url(key, model, &base_url)?
.with_max_tokens(client_config.max_tokens)
.with_temperature(client_config.temperature)
.with_timeout(client_config.timeout)?,
)
} }
_ => bail!("Unknown LLM provider: {}", config.provider), _ => bail!("Unknown LLM provider: {}", provider),
}; };
Ok(Self { Ok(Self {
@@ -123,7 +203,7 @@ impl LlmClient {
language: Language, language: Language,
) -> Result<GeneratedCommit> { ) -> Result<GeneratedCommit> {
let system_prompt = get_commit_system_prompt(format, language); let system_prompt = get_commit_system_prompt(format, language);
// Add language instruction to the prompt // Add language instruction to the prompt
let language_instruction = match language { let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成提交消息。", Language::Chinese => "\n\n请用中文生成提交消息。",
@@ -134,10 +214,13 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie die Commit-Nachricht auf Deutsch.", Language::German => "\n\nBitte generieren Sie die Commit-Nachricht auf Deutsch.",
Language::English => "", Language::English => "",
}; };
let prompt = format!("{}{}", diff, language_instruction); let prompt = format!("{}{}", diff, language_instruction);
let response = self.provider.generate_with_system(system_prompt, &prompt).await?; let response = self
.provider
.generate_with_system(system_prompt, &prompt)
.await?;
self.parse_commit_response(&response, format) self.parse_commit_response(&response, format)
} }
@@ -150,7 +233,7 @@ impl LlmClient {
) -> Result<String> { ) -> Result<String> {
let system_prompt = get_tag_system_prompt(language); let system_prompt = get_tag_system_prompt(language);
let commits_text = commits.join("\n"); let commits_text = commits.join("\n");
// Add language instruction to the prompt // Add language instruction to the prompt
let language_instruction = match language { let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成标签消息。", Language::Chinese => "\n\n请用中文生成标签消息。",
@@ -161,10 +244,15 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie die Tag-Nachricht auf Deutsch.", Language::German => "\n\nBitte generieren Sie die Tag-Nachricht auf Deutsch.",
Language::English => "", Language::English => "",
}; };
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction); let prompt = format!(
"Version: {}\n\nCommits:\n{}{}",
self.provider.generate_with_system(system_prompt, &prompt).await version, commits_text, language_instruction
);
self.provider
.generate_with_system(system_prompt, &prompt)
.await
} }
/// Generate changelog entry /// Generate changelog entry
@@ -175,13 +263,13 @@ impl LlmClient {
language: Language, language: Language,
) -> Result<String> { ) -> Result<String> {
let system_prompt = get_changelog_system_prompt(language); let system_prompt = get_changelog_system_prompt(language);
let commits_text = commits let commits_text = commits
.iter() .iter()
.map(|(t, m)| format!("- [{}] {}", t, m)) .map(|(t, m)| format!("- [{}] {}", t, m))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
// Add language instruction to the prompt // Add language instruction to the prompt
let language_instruction = match language { let language_instruction = match language {
Language::Chinese => "\n\n请用中文生成变更日志。", Language::Chinese => "\n\n请用中文生成变更日志。",
@@ -192,10 +280,15 @@ impl LlmClient {
Language::German => "\n\nBitte generieren Sie das Changelog auf Deutsch.", Language::German => "\n\nBitte generieren Sie das Changelog auf Deutsch.",
Language::English => "", Language::English => "",
}; };
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction); let prompt = format!(
"Version: {}\n\nCommits:\n{}{}",
self.provider.generate_with_system(system_prompt, &prompt).await version, commits_text, language_instruction
);
self.provider
.generate_with_system(system_prompt, &prompt)
.await
} }
/// Check if provider is available /// Check if provider is available
@@ -204,35 +297,115 @@ impl LlmClient {
} }
/// Parse commit response from LLM /// Parse commit response from LLM
fn parse_commit_response(&self, response: &str, format: crate::config::CommitFormat) -> Result<GeneratedCommit> { fn parse_commit_response(
let lines: Vec<&str> = response.lines().collect(); &self,
response: &str,
format: crate::config::CommitFormat,
) -> Result<GeneratedCommit> {
// Clean markdown code fences from the response
let cleaned = Self::strip_code_fences(response);
let lines: Vec<&str> = cleaned
.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.collect();
if lines.is_empty() { if lines.is_empty() {
bail!("Empty response from LLM"); let preview: String = response.chars().take(200).collect();
bail!(
"LLM returned empty or whitespace-only response. \
Raw response preview: '{}'. \
Hint: If using DeepSeek/Kimi with thinking enabled, \
the model may have returned reasoning_content only. \
Try disabling thinking mode or switching models.",
preview
);
} }
let first_line = lines[0]; // Find the line most likely to be the commit subject
let first_line = Self::find_commit_subject_line(&lines, format);
// Parse based on format // Parse based on format
match format { match format {
crate::config::CommitFormat::Conventional => { crate::config::CommitFormat::Conventional => {
self.parse_conventional_commit(first_line, lines) self.parse_conventional_commit(first_line, &lines, response)
} }
crate::config::CommitFormat::Commitlint => { crate::config::CommitFormat::Commitlint => {
self.parse_commitlint_commit(first_line, lines) self.parse_commitlint_commit(first_line, &lines, response)
} }
} }
} }
/// Remove surrounding markdown code fences (```) from LLM output
fn strip_code_fences(response: &str) -> String {
let mut lines: Vec<&str> = response.lines().collect();
// Strip leading fence lines (``` or ```lang)
while lines.first().map_or(false, |l| l.trim().starts_with("```")) {
lines.remove(0);
}
// Strip trailing fence lines
while lines.last().map_or(false, |l| l.trim() == "```") {
lines.pop();
}
lines.join("\n")
}
/// Find the line that is most likely the commit subject among extracted lines
fn find_commit_subject_line<'a>(
lines: &[&'a str],
format: crate::config::CommitFormat,
) -> &'a str {
let valid_types = crate::utils::validators::get_commit_types(matches!(
format,
crate::config::CommitFormat::Commitlint
));
// First pass: line starting with a known type that also has proper syntax
// (e.g. "type:", "type(scope):", "type!:")
for &line in lines {
let trimmed = line.trim();
for &t in valid_types {
if let Some(rest) = trimmed.strip_prefix(t) {
if rest.starts_with(':') || rest.starts_with('(') || rest.starts_with("!:") {
return trimmed;
}
}
}
}
// Second pass: any line containing a colon (generic "prefix: description")
for &line in lines {
if line.contains(':') {
return line.trim();
}
}
// Fallback: return the first line as-is
lines[0].trim()
}
fn parse_conventional_commit( fn parse_conventional_commit(
&self, &self,
first_line: &str, first_line: &str,
lines: Vec<&str>, lines: &[&str],
raw_response: &str,
) -> Result<GeneratedCommit> { ) -> Result<GeneratedCommit> {
// Parse: type(scope)!: description // Parse: type(scope)!: description
let parts: Vec<&str> = first_line.splitn(2, ':').collect(); let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 { if parts.len() != 2 {
bail!("Invalid conventional commit format: missing colon"); let preview: String = raw_response.chars().take(300).collect();
bail!(
"Invalid conventional commit format: missing colon.\n\
Parsed subject line: '{}'\n\
Raw response preview: '{}'\n\
Expected: <type>[optional scope]: <description>",
first_line,
preview
);
} }
let type_part = parts[0]; let type_part = parts[0];
@@ -255,7 +428,7 @@ impl LlmClient {
}; };
// Extract body and footer // Extract body and footer
let (body, footer) = self.extract_body_footer(&lines); let (body, footer) = self.extract_body_footer(lines);
Ok(GeneratedCommit { Ok(GeneratedCommit {
commit_type, commit_type,
@@ -270,12 +443,21 @@ impl LlmClient {
fn parse_commitlint_commit( fn parse_commitlint_commit(
&self, &self,
first_line: &str, first_line: &str,
lines: Vec<&str>, lines: &[&str],
raw_response: &str,
) -> Result<GeneratedCommit> { ) -> Result<GeneratedCommit> {
// Similar parsing but with commitlint rules // Similar parsing but with commitlint rules
let parts: Vec<&str> = first_line.splitn(2, ':').collect(); let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 { if parts.len() != 2 {
bail!("Invalid commit format: missing colon"); let preview: String = raw_response.chars().take(300).collect();
bail!(
"Invalid commit format: missing colon.\n\
Parsed subject line: '{}'\n\
Raw response preview: '{}'\n\
Expected: <type>[optional scope]: <subject>",
first_line,
preview
);
} }
let type_part = parts[0]; let type_part = parts[0];
@@ -321,8 +503,14 @@ impl LlmClient {
} }
// Look for footer markers // Look for footer markers
let footer_markers = ["BREAKING CHANGE:", "Closes", "Fixes", "Refs", "Co-authored-by:"]; let footer_markers = [
"BREAKING CHANGE:",
"Closes",
"Fixes",
"Refs",
"Co-authored-by:",
];
let mut body_lines = vec![]; let mut body_lines = vec![];
let mut footer_lines = vec![]; let mut footer_lines = vec![];
let mut in_footer = false; let mut in_footer = false;
@@ -331,7 +519,7 @@ impl LlmClient {
if footer_markers.iter().any(|m| line.starts_with(m)) { if footer_markers.iter().any(|m| line.starts_with(m)) {
in_footer = true; in_footer = true;
} }
if in_footer { if in_footer {
footer_lines.push(*line); footer_lines.push(*line);
} else { } else {
@@ -401,17 +589,34 @@ pub(crate) fn create_http_client(timeout: Duration) -> Result<reqwest::Client> {
} }
/// Get commit system prompt based on format and language /// Get commit system prompt based on format and language
fn get_commit_system_prompt(format: crate::config::CommitFormat, language: Language) -> &'static str { fn get_commit_system_prompt(
format: crate::config::CommitFormat,
language: Language,
) -> &'static str {
match (format, language) { match (format, language) {
(crate::config::CommitFormat::Conventional, Language::Chinese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH, (crate::config::CommitFormat::Conventional, Language::Chinese) => {
(crate::config::CommitFormat::Conventional, Language::Japanese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA, CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH
(crate::config::CommitFormat::Conventional, Language::Korean) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO, }
(crate::config::CommitFormat::Conventional, Language::Spanish) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES, (crate::config::CommitFormat::Conventional, Language::Japanese) => {
(crate::config::CommitFormat::Conventional, Language::French) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR, CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA
(crate::config::CommitFormat::Conventional, Language::German) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE, }
(crate::config::CommitFormat::Conventional, Language::Korean) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO
}
(crate::config::CommitFormat::Conventional, Language::Spanish) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES
}
(crate::config::CommitFormat::Conventional, Language::French) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR
}
(crate::config::CommitFormat::Conventional, Language::German) => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE
}
(crate::config::CommitFormat::Conventional, _) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT, (crate::config::CommitFormat::Conventional, _) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT,
(crate::config::CommitFormat::Commitlint, Language::Chinese) => COMMITLINT_SYSTEM_PROMPT_ZH, (crate::config::CommitFormat::Commitlint, Language::Chinese) => COMMITLINT_SYSTEM_PROMPT_ZH,
(crate::config::CommitFormat::Commitlint, Language::Japanese) => COMMITLINT_SYSTEM_PROMPT_JA, (crate::config::CommitFormat::Commitlint, Language::Japanese) => {
COMMITLINT_SYSTEM_PROMPT_JA
}
(crate::config::CommitFormat::Commitlint, Language::Korean) => COMMITLINT_SYSTEM_PROMPT_KO, (crate::config::CommitFormat::Commitlint, Language::Korean) => COMMITLINT_SYSTEM_PROMPT_KO,
(crate::config::CommitFormat::Commitlint, Language::Spanish) => COMMITLINT_SYSTEM_PROMPT_ES, (crate::config::CommitFormat::Commitlint, Language::Spanish) => COMMITLINT_SYSTEM_PROMPT_ES,
(crate::config::CommitFormat::Commitlint, Language::French) => COMMITLINT_SYSTEM_PROMPT_FR, (crate::config::CommitFormat::Commitlint, Language::French) => COMMITLINT_SYSTEM_PROMPT_FR,
@@ -502,8 +707,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 C
4. 不要大写首字母 4. 不要大写首字母
5. 结尾不要句号 5. 结尾不要句号
6. 如果更改特定于模块/组件,请包含作用域 6. 如果更改特定于模块/组件,请包含作用域
7. 仅输出提交消息,不要输出其他内容。
仅输出提交消息,不要输出其他内容。
"#; "#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA: &str = r#"あなたはConventional Commits仕様に従ったコミットメッセージを生成するアシスタントです。 const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA: &str = r#"あなたはConventional Commits仕様に従ったコミットメッセージを生成するアシスタントです。
@@ -532,8 +736,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA: &str = r#"あなたはConventional C
4. 先頭を大文字にしない 4. 先頭を大文字にしない
5. 最後にピリオドを付けない 5. 最後にピリオドを付けない
6. 変更がモジュール/コンポーネントに固有の場合はスコープを含める 6. 変更がモジュール/コンポーネントに固有の場合はスコープを含める
7. コミットメッセージのみを出力し、それ以外は出力しないでください。
コミットメッセージのみを出力し、それ以外は出力しないでください。
"#; "#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO: &str = r#"당신은 Conventional Commits 사양에 따른 커밋 메시지를 생성하는 도우미입니다. const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO: &str = r#"당신은 Conventional Commits 사양에 따른 커밋 메시지를 생성하는 도우미입니다.
@@ -562,8 +765,7 @@ const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO: &str = r#"당신은 Conventional Com
4. 첫 글자 대문자화하지 않음 4. 첫 글자 대문자화하지 않음
5. 끝에 마침표 사용하지 않음 5. 끝에 마침표 사용하지 않음
6. 변경 사항이 모듈/구성 요소에 특정한 경우 범위 포함 6. 변경 사항이 모듈/구성 요소에 특정한 경우 범위 포함
7. 커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
"#; "#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo la especificación Conventional Commits. const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo la especificación Conventional Commits.
@@ -592,8 +794,7 @@ Reglas:
4. No capitalices la primera letra 4. No capitalices la primera letra
5. Sin punto al final 5. Sin punto al final
6. Incluye alcance si el cambio es específico de un módulo/componente 6. Incluye alcance si el cambio es específico de un módulo/componente
7. Genera SOLO el mensaje de commit, nada más.
Genera SOLO el mensaje de commit, nada más.
"#; "#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant la spécification Conventional Commits. const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant la spécification Conventional Commits.
@@ -622,8 +823,7 @@ Règles:
4. Ne capitalisez pas la première lettre 4. Ne capitalisez pas la première lettre
5. Pas de point à la fin 5. Pas de point à la fin
6. Incluez la portée si le changement est spécifique à un module/composant 6. Incluez la portée si le changement est spécifique à un module/composant
7. Générez SEULEMENT le message de commit, rien d'autre.
Générez SEULEMENT le message de commit, rien d'autre.
"#; "#;
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß der Conventional Commits-Spezifikation generiert. const CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß der Conventional Commits-Spezifikation generiert.
@@ -652,8 +852,7 @@ Regeln:
4. Großschreiben Sie den ersten Buchstaben nicht 4. Großschreiben Sie den ersten Buchstaben nicht
5. Kein Punkt am Ende 5. Kein Punkt am Ende
6. Fügen Sie einen Bereich ein, wenn die Änderung spezifisch für ein Modul/Komponente ist 6. Fügen Sie einen Bereich ein, wenn die Änderung spezifisch für ein Modul/Komponente ist
7. Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
"#; "#;
const COMMITLINT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates commit messages following @commitlint/config-conventional. const COMMITLINT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates commit messages following @commitlint/config-conventional.
@@ -670,8 +869,7 @@ Rules:
3. Subject should be 4-100 characters 3. Subject should be 4-100 characters
4. Use imperative mood 4. Use imperative mood
5. Be concise but descriptive 5. Be concise but descriptive
6. Output ONLY the commit message, nothing else.
Output ONLY the commit message, nothing else.
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 @commitlint/config-conventional 规范的提交消息的助手。 const COMMITLINT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 @commitlint/config-conventional 规范的提交消息的助手。
@@ -688,8 +886,7 @@ const COMMITLINT_SYSTEM_PROMPT_ZH: &str = r#"你是一个生成符合 @commitlin
3. 主题应为 4-100 个字符 3. 主题应为 4-100 个字符
4. 使用祈使语气 4. 使用祈使语气
5. 简洁但描述性强 5. 简洁但描述性强
6. 仅输出提交消息,不要输出其他额外内容。
仅输出提交消息,不要输出其他内容。
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_JA: &str = r#"あなたは@commitlint/config-conventionalに従ったコミットメッセージを生成するアシスタントです。 const COMMITLINT_SYSTEM_PROMPT_JA: &str = r#"あなたは@commitlint/config-conventionalに従ったコミットメッセージを生成するアシスタントです。
@@ -706,8 +903,7 @@ git diffを分析し、コミットメッセージを生成してください。
3. 件名は4-100文字である必要があります 3. 件名は4-100文字である必要があります
4. 命令形を使用してください 4. 命令形を使用してください
5. 簡潔ですが説明的であること 5. 簡潔ですが説明的であること
6. コミットメッセージのみを出力し、それ以外は出力しないでください。
コミットメッセージのみを出力し、それ以外は出力しないでください。
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_KO: &str = r#"당신은 @commitlint/config-conventional에 따른 커밋 메시지를 생성하는 도우미입니다. const COMMITLINT_SYSTEM_PROMPT_KO: &str = r#"당신은 @commitlint/config-conventional에 따른 커밋 메시지를 생성하는 도우미입니다.
@@ -724,8 +920,7 @@ git diff를 분석하고 커밋 메시지를 생성하세요.
3. 제목은 4-100자여야 합니다 3. 제목은 4-100자여야 합니다
4. 명령형을 사용하세요 4. 명령형을 사용하세요
5. 간결하지만 설명적이어야 합니다 5. 간결하지만 설명적이어야 합니다
6. 커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
커밋 메시지만 출력하고 다른 내용은 출력하지 마세요.
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo @commitlint/config-conventional. const COMMITLINT_SYSTEM_PROMPT_ES: &str = r#"Eres un asistente que genera mensajes de commit siguiendo @commitlint/config-conventional.
@@ -742,8 +937,7 @@ Reglas:
3. El asunto debe tener 4-100 caracteres 3. El asunto debe tener 4-100 caracteres
4. Usa modo imperativo 4. Usa modo imperativo
5. Sé conciso pero descriptivo 5. Sé conciso pero descriptivo
6. Genera SOLO el mensaje de commit, nada más.
Genera SOLO el mensaje de commit, nada más.
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant @commitlint/config-conventional. const COMMITLINT_SYSTEM_PROMPT_FR: &str = r#"Vous êtes un assistant qui génère des messages de commit suivant @commitlint/config-conventional.
@@ -760,8 +954,7 @@ Règles:
3. Le sujet doit avoir 4-100 caractères 3. Le sujet doit avoir 4-100 caractères
4. Utilisez le mode impératif 4. Utilisez le mode impératif
5. Soyez concis mais descriptif 5. Soyez concis mais descriptif
6. Générez SEULEMENT le message de commit, rien d'autre.
Générez SEULEMENT le message de commit, rien d'autre.
"#; "#;
const COMMITLINT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß @commitlint/config-conventional generiert. const COMMITLINT_SYSTEM_PROMPT_DE: &str = r#"Sie sind ein Assistent, der Commit-Nachrichten gemäß @commitlint/config-conventional generiert.
@@ -778,8 +971,7 @@ Regeln:
3. Der Betreff sollte 4-100 Zeichen haben 3. Der Betreff sollte 4-100 Zeichen haben
4. Verwenden Sie den Imperativ 4. Verwenden Sie den Imperativ
5. Seien Sie prägnant aber beschreibend 5. Seien Sie prägnant aber beschreibend
6. Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
Geben Sie NUR die Commit-Nachricht aus, nichts anderes.
"#; "#;
const TAG_MESSAGE_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates git tag annotation messages. const TAG_MESSAGE_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates git tag annotation messages.
@@ -1012,3 +1204,10 @@ Gruppieren Sie Commits nach:
Formatieren Sie in Markdown mit geeigneten Überschriften und Aufzählungspunkten. Formatieren Sie in Markdown mit geeigneten Überschriften und Aufzählungspunkten.
"#; "#;
/// Test LLM connection
pub async fn test_connection(manager: &crate::config::manager::ConfigManager) -> Result<String> {
let client = LlmClient::from_config(manager).await?;
let response = client.provider.generate("Say 'Hello, World!'").await?;
Ok(response)
}

View File

@@ -1,4 +1,4 @@
use super::{create_http_client, LlmProvider}; use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -9,6 +9,9 @@ pub struct OllamaClient {
base_url: String, base_url: String,
model: String, model: String,
client: reqwest::Client, client: reqwest::Client,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -47,69 +50,88 @@ struct ModelInfo {
impl OllamaClient { impl OllamaClient {
/// Create new Ollama client /// Create new Ollama client
pub fn new(base_url: &str, model: &str) -> Self { pub fn new(base_url: &str, model: &str) -> Self {
let client = create_http_client(Duration::from_secs(120)) let client =
.expect("Failed to create HTTP client"); create_http_client(Duration::from_secs(120)).expect("Failed to create HTTP client");
Self { Self {
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
model: model.to_string(), model: model.to_string(),
client, client,
max_tokens: 500,
temperature: 0.7,
top_p: None,
} }
} }
/// Set timeout /// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Self { pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.client = create_http_client(timeout) self.client = create_http_client(timeout).expect("Failed to create HTTP client");
.expect("Failed to create HTTP client"); self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self self
} }
/// List available models /// List available models
pub async fn list_models(&self) -> Result<Vec<String>> { pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/api/tags", self.base_url); let url = format!("{}/api/tags", self.base_url);
let response = self.client let response = self
.client
.get(&url) .get(&url)
.send() .send()
.await .await
.context("Failed to list Ollama models")?; .context("Failed to list Ollama models")?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text); anyhow::bail!("Ollama API error: {} - {}", status, text);
} }
let result: ListModelsResponse = response let result: ListModelsResponse = response
.json() .json()
.await .await
.context("Failed to parse Ollama response")?; .context("Failed to parse Ollama response")?;
Ok(result.models.into_iter().map(|m| m.name).collect()) Ok(result.models.into_iter().map(|m| m.name).collect())
} }
/// Pull a model /// Pull a model
pub async fn pull_model(&self, model: &str) -> Result<()> { pub async fn pull_model(&self, model: &str) -> Result<()> {
let url = format!("{}/api/pull", self.base_url); let url = format!("{}/api/pull", self.base_url);
let request = serde_json::json!({ let request = serde_json::json!({
"name": model, "name": model,
"stream": false, "stream": false,
}); });
let response = self.client let response = self
.client
.post(&url) .post(&url)
.json(&request) .json(&request)
.send() .send()
.await .await
.context("Failed to pull Ollama model")?; .context("Failed to pull Ollama model")?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama pull error: {} - {}", status, text); anyhow::bail!("Ollama pull error: {} - {}", status, text);
} }
Ok(()) Ok(())
} }
@@ -130,48 +152,49 @@ impl LlmProvider for OllamaClient {
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let url = format!("{}/api/generate", self.base_url); let url = format!("{}/api/generate", self.base_url);
let system = if system.is_empty() { let system = if system.is_empty() {
None None
} else { } else {
Some(system.to_string()) Some(system.to_string())
}; };
let request = GenerateRequest { let request = GenerateRequest {
model: self.model.clone(), model: self.model.clone(),
prompt: user.to_string(), prompt: user.to_string(),
system, system,
stream: false, stream: false,
options: GenerationOptions { options: GenerationOptions {
temperature: Some(0.7), temperature: Some(self.temperature),
num_predict: Some(500), num_predict: Some(self.max_tokens),
}, },
}; };
let response = self.client let response = self
.client
.post(&url) .post(&url)
.json(&request) .json(&request)
.send() .send()
.await .await
.context("Failed to send request to Ollama")?; .context("Failed to send request to Ollama")?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text); anyhow::bail!("Ollama API error: {} - {}", status, text);
} }
let result: GenerateResponse = response let result: GenerateResponse = response
.json() .json()
.await .await
.context("Failed to parse Ollama response")?; .context("Failed to parse Ollama response")?;
Ok(result.response.trim().to_string()) Ok(result.response.trim().to_string())
} }
async fn is_available(&self) -> bool { async fn is_available(&self) -> bool {
let url = format!("{}/api/tags", self.base_url); let url = format!("{}/api/tags", self.base_url);
match self.client.get(&url).send().await { match self.client.get(&url).send().await {
Ok(response) => response.status().is_success(), Ok(response) => response.status().is_success(),
Err(_) => false, Err(_) => false,

View File

@@ -1,15 +1,23 @@
use super::{create_http_client, LlmProvider}; use super::thinking::ThinkingStateManager;
use anyhow::{bail, Context, Result}; use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
/// OpenAI API client /// OpenAI API client with o-series reasoning support
pub struct OpenAiClient { pub struct OpenAiClient {
base_url: String, base_url: String,
api_key: String, api_key: String,
model: String, model: String,
client: reqwest::Client, client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@@ -20,10 +28,14 @@ struct ChatCompletionRequest {
max_tokens: Option<u32>, max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>, temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
stream: bool, stream: bool,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize, Clone)]
struct Message { struct Message {
role: String, role: String,
content: String, content: String,
@@ -39,6 +51,28 @@ struct Choice {
message: Message, message: Message,
} }
// --- Streaming response structures ---
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ErrorResponse { struct ErrorResponse {
error: ApiError, error: ApiError,
@@ -55,57 +89,91 @@ impl OpenAiClient {
/// Create new OpenAI client /// Create new OpenAI client
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> { pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; let client = create_http_client(Duration::from_secs(60))?;
Ok(Self { Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(), api_key: api_key.to_string(),
model: model.to_string(), model: model.to_string(),
client, client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
}) })
} }
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> { pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?; self.client = create_http_client(timeout)?;
Ok(self) Ok(self)
} }
/// List available models pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
self.reasoning_effort = effort;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> { pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url); let url = format!("{}/models", self.base_url);
let response = self.client let response = self
.client
.get(&url) .get(&url)
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.send() .send()
.await .await
.context("Failed to list OpenAI models")?; .context("Failed to list OpenAI models")?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
bail!("OpenAI API error: {} - {}", status, text); bail!("OpenAI API error: {} - {}", status, text);
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct ModelsResponse { struct ModelsResponse {
data: Vec<Model>, data: Vec<Model>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
struct Model { struct Model {
id: String, id: String,
} }
let result: ModelsResponse = response let result: ModelsResponse = response
.json() .json()
.await .await
.context("Failed to parse OpenAI response")?; .context("Failed to parse OpenAI response")?;
Ok(result.data.into_iter().map(|m| m.id).collect()) Ok(result.data.into_iter().map(|m| m.id).collect())
} }
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> { pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await { match self.list_models().await {
Ok(_) => Ok(true), Ok(_) => Ok(true),
@@ -124,32 +192,30 @@ impl OpenAiClient {
#[async_trait] #[async_trait]
impl LlmProvider for OpenAiClient { impl LlmProvider for OpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> { async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![ let messages = vec![Message {
Message { role: "user".to_string(),
role: "user".to_string(), content: prompt.to_string(),
content: prompt.to_string(), }];
},
]; self.chat_completion_with_retry(messages).await
self.chat_completion(messages).await
} }
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![]; let mut messages = vec![];
if !system.is_empty() { if !system.is_empty() {
messages.push(Message { messages.push(Message {
role: "system".to_string(), role: "system".to_string(),
content: system.to_string(), content: system.to_string(),
}); });
} }
messages.push(Message { messages.push(Message {
role: "user".to_string(), role: "user".to_string(),
content: user.to_string(), content: user.to_string(),
}); });
self.chat_completion(messages).await self.chat_completion_with_retry(messages).await
} }
async fn is_available(&self) -> bool { async fn is_available(&self) -> bool {
@@ -162,18 +228,63 @@ impl LlmProvider for OpenAiClient {
} }
impl OpenAiClient { impl OpenAiClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> { async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
if self.thinking_enabled {
self.streaming_chat_completion(messages).await
} else {
self.non_streaming_chat_completion(messages).await
}
}
async fn non_streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.model.clone(), model: self.model.clone(),
messages, messages,
max_tokens: Some(500), max_tokens: Some(self.max_tokens),
temperature: Some(0.7), temperature: Some(self.temperature),
top_p: self.top_p,
reasoning_effort: if is_reasoning_model(&self.model) {
Some("none".to_string())
} else {
None
},
stream: false, stream: false,
}; };
let response = self.client let response = self
.client
.post(&url) .post(&url)
.header("Authorization", format!("Bearer {}", self.api_key)) .header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
@@ -181,31 +292,166 @@ impl OpenAiClient {
.send() .send()
.await .await
.context("Failed to send request to OpenAI")?; .context("Failed to send request to OpenAI")?;
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) { if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type); bail!(
"OpenAI API error: {} ({})",
error.error.message,
error.error.error_type
);
} }
bail!("OpenAI API error: {} - {}", status, text); bail!("OpenAI API error: {} - {}", status, text);
} }
let result: ChatCompletionResponse = response let result: ChatCompletionResponse = response
.json() .json()
.await .await
.context("Failed to parse OpenAI response")?; .context("Failed to parse OpenAI response")?;
result.choices result
.choices
.into_iter() .into_iter()
.next() .next()
.map(|c| c.message.content.trim().to_string()) .map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
} }
/// Streaming request for reasoning mode, filters reasoning_content from output
async fn streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
// For reasoning/thinking mode, omit temperature and top_p
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(self.max_tokens),
temperature: None,
top_p: None,
reasoning_effort: self.reasoning_effort.clone(),
stream: true,
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(&request)
.send()
.await
.context("Failed to send streaming request to OpenAI")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"OpenAI API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("OpenAI API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
if line == "data: [DONE]" {
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(json_str) {
for choice in &chunk.choices {
// Handle reasoning_content (o-series)
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
continue;
}
// Handle content
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
}
}
}
}
}
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"OpenAI returned reasoning content but no final answer. \
The model may have entered an incomplete reasoning state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from OpenAI. \
If thinking mode is enabled, try disabling it or ensure the model supports reasoning."
);
}
Ok(result)
}
} }
/// Azure OpenAI client (extends OpenAI with Azure-specific config) /// Azure OpenAI client (extends OpenAI with Azure-specific config)
@@ -215,24 +461,30 @@ pub struct AzureOpenAiClient {
deployment: String, deployment: String,
api_version: String, api_version: String,
client: reqwest::Client, client: reqwest::Client,
thinking_enabled: bool,
reasoning_effort: Option<String>,
max_tokens: u32,
temperature: f32,
top_p: Option<f32>,
thinking_state: Option<Arc<ThinkingStateManager>>,
} }
impl AzureOpenAiClient { impl AzureOpenAiClient {
/// Create new Azure OpenAI client pub fn new(endpoint: &str, api_key: &str, deployment: &str, api_version: &str) -> Result<Self> {
pub fn new(
endpoint: &str,
api_key: &str,
deployment: &str,
api_version: &str,
) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; let client = create_http_client(Duration::from_secs(60))?;
Ok(Self { Ok(Self {
endpoint: endpoint.trim_end_matches('/').to_string(), endpoint: endpoint.trim_end_matches('/').to_string(),
api_key: api_key.to_string(), api_key: api_key.to_string(),
deployment: deployment.to_string(), deployment: deployment.to_string(),
api_version: api_version.to_string(), api_version: api_version.to_string(),
client, client,
thinking_enabled: false,
reasoning_effort: None,
max_tokens: 500,
temperature: 0.7,
top_p: None,
thinking_state: None,
}) })
} }
@@ -241,16 +493,19 @@ impl AzureOpenAiClient {
"{}/openai/deployments/{}/chat/completions?api-version={}", "{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version self.endpoint, self.deployment, self.api_version
); );
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.deployment.clone(), model: self.deployment.clone(),
messages, messages,
max_tokens: Some(500), max_tokens: Some(self.max_tokens),
temperature: Some(0.7), temperature: Some(self.temperature),
top_p: self.top_p,
reasoning_effort: self.reasoning_effort.clone(),
stream: false, stream: false,
}; };
let response = self.client let response = self
.client
.post(&url) .post(&url)
.header("api-key", &self.api_key) .header("api-key", &self.api_key)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
@@ -258,22 +513,24 @@ impl AzureOpenAiClient {
.send() .send()
.await .await
.context("Failed to send request to Azure OpenAI")?; .context("Failed to send request to Azure OpenAI")?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let text = response.text().await.unwrap_or_default(); let text = response.text().await.unwrap_or_default();
bail!("Azure OpenAI API error: {} - {}", status, text); bail!("Azure OpenAI API error: {} - {}", status, text);
} }
let result: ChatCompletionResponse = response let result: ChatCompletionResponse = response
.json() .json()
.await .await
.context("Failed to parse Azure OpenAI response")?; .context("Failed to parse Azure OpenAI response")?;
result.choices result
.choices
.into_iter() .into_iter()
.next() .next()
.map(|c| c.message.content.trim().to_string()) .map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI")) .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
} }
} }
@@ -281,41 +538,38 @@ impl AzureOpenAiClient {
#[async_trait] #[async_trait]
impl LlmProvider for AzureOpenAiClient { impl LlmProvider for AzureOpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> { async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![ let messages = vec![Message {
Message { role: "user".to_string(),
role: "user".to_string(), content: prompt.to_string(),
content: prompt.to_string(), }];
},
];
self.chat_completion(messages).await self.chat_completion(messages).await
} }
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![]; let mut messages = vec![];
if !system.is_empty() { if !system.is_empty() {
messages.push(Message { messages.push(Message {
role: "system".to_string(), role: "system".to_string(),
content: system.to_string(), content: system.to_string(),
}); });
} }
messages.push(Message { messages.push(Message {
role: "user".to_string(), role: "user".to_string(),
content: user.to_string(), content: user.to_string(),
}); });
self.chat_completion(messages).await self.chat_completion(messages).await
} }
async fn is_available(&self) -> bool { async fn is_available(&self) -> bool {
// Simple check - try to make a minimal request
let url = format!( let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}", "{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version self.endpoint, self.deployment, self.api_version
); );
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.deployment.clone(), model: self.deployment.clone(),
messages: vec![Message { messages: vec![Message {
@@ -324,10 +578,13 @@ impl LlmProvider for AzureOpenAiClient {
}], }],
max_tokens: Some(5), max_tokens: Some(5),
temperature: Some(0.0), temperature: Some(0.0),
top_p: None,
reasoning_effort: None,
stream: false, stream: false,
}; };
match self.client match self
.client
.post(&url) .post(&url)
.header("api-key", &self.api_key) .header("api-key", &self.api_key)
.json(&request) .json(&request)
@@ -343,3 +600,60 @@ impl LlmProvider for AzureOpenAiClient {
"azure-openai" "azure-openai"
} }
} }
/// Available OpenAI models (including o-series with reasoning)
pub const OPENAI_MODELS: &[&str] = &[
"o4-mini",
"o3",
"o3-mini",
"o1",
"o1-mini",
"o1-pro",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-3.5-turbo",
];
pub fn is_valid_model(model: &str) -> bool {
OPENAI_MODELS.contains(&model)
}
fn is_reasoning_model(model: &str) -> bool {
model.starts_with("o")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation_o_series() {
assert!(is_valid_model("o4-mini"));
assert!(is_valid_model("o3"));
assert!(is_valid_model("o1"));
assert!(is_valid_model("gpt-4o"));
assert!(is_valid_model("gpt-3.5-turbo"));
assert!(!is_valid_model("invalid-model"));
}
#[test]
fn test_stream_delta_reasoning_parsing() {
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert!(delta.content.is_none());
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
}
#[test]
fn test_stream_delta_content_parsing() {
let json = r#"{"content":"Hello","reasoning_content":null}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert_eq!(delta.content, Some("Hello".to_string()));
assert!(delta.reasoning_content.is_none());
}
}

View File

@@ -1,257 +1,286 @@
use super::{create_http_client, LlmProvider}; use super::{LlmProvider, create_http_client};
use anyhow::{bail, Context, Result}; use anyhow::{Context, Result, bail};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::Duration; use std::time::Duration;
/// OpenRouter API client /// OpenRouter API client
pub struct OpenRouterClient { pub struct OpenRouterClient {
base_url: String, base_url: String,
api_key: String, api_key: String,
model: String, model: String,
client: reqwest::Client, client: reqwest::Client,
} max_tokens: u32,
temperature: f32,
#[derive(Debug, Serialize)] top_p: Option<f32>,
struct ChatCompletionRequest { }
model: String,
messages: Vec<Message>, #[derive(Debug, Serialize)]
#[serde(skip_serializing_if = "Option::is_none")] struct ChatCompletionRequest {
max_tokens: Option<u32>, model: String,
#[serde(skip_serializing_if = "Option::is_none")] messages: Vec<Message>,
temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
stream: bool, max_tokens: Option<u32>,
} #[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[derive(Debug, Serialize, Deserialize)] stream: bool,
struct Message { }
role: String,
content: String, #[derive(Debug, Serialize, Deserialize)]
} struct Message {
role: String,
#[derive(Debug, Deserialize)] content: String,
struct ChatCompletionResponse { }
choices: Vec<Choice>,
} #[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
#[derive(Debug, Deserialize)] choices: Vec<Choice>,
struct Choice { }
message: Message,
} #[derive(Debug, Deserialize)]
struct Choice {
#[derive(Debug, Deserialize)] message: Message,
struct ErrorResponse { }
error: ApiError,
} #[derive(Debug, Deserialize)]
struct ErrorResponse {
#[derive(Debug, Deserialize)] error: ApiError,
struct ApiError { }
message: String,
#[serde(rename = "type")] #[derive(Debug, Deserialize)]
error_type: String, struct ApiError {
} message: String,
#[serde(rename = "type")]
impl OpenRouterClient { error_type: String,
/// Create new OpenRouter client }
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?; impl OpenRouterClient {
/// Create new OpenRouter client
Ok(Self { pub fn new(api_key: &str, model: &str) -> Result<Self> {
base_url: "https://openrouter.ai/api/v1".to_string(), let client = create_http_client(Duration::from_secs(60))?;
api_key: api_key.to_string(),
model: model.to_string(), Ok(Self {
client, base_url: "https://openrouter.ai/api/v1".to_string(),
}) api_key: api_key.to_string(),
} model: model.to_string(),
client,
/// Create with custom base URL max_tokens: 500,
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> { temperature: 0.7,
let client = create_http_client(Duration::from_secs(60))?; top_p: None,
})
Ok(Self { }
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(), /// Create with custom base URL
model: model.to_string(), pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
client, let client = create_http_client(Duration::from_secs(60))?;
})
} Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
/// Set timeout api_key: api_key.to_string(),
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> { model: model.to_string(),
self.client = create_http_client(timeout)?; client,
Ok(self) max_tokens: 500,
} temperature: 0.7,
top_p: None,
/// List available models })
pub async fn list_models(&self) -> Result<Vec<String>> { }
let url = format!("{}/models", self.base_url);
/// Set timeout
let response = self.client pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
.get(&url) self.client = create_http_client(timeout)?;
.header("Authorization", format!("Bearer {}", self.api_key)) Ok(self)
.header("HTTP-Referer", "https://quicommit.dev") }
.header("X-Title", "QuiCommit")
.send() pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
.await self.max_tokens = max_tokens;
.context("Failed to list OpenRouter models")?; self
}
if !response.status().is_success() {
let status = response.status(); pub fn with_temperature(mut self, temperature: f32) -> Self {
let text = response.text().await.unwrap_or_default(); self.temperature = temperature;
bail!("OpenRouter API error: {} - {}", status, text); self
} }
#[derive(Deserialize)] pub fn with_top_p(mut self, top_p: f32) -> Self {
struct ModelsResponse { self.top_p = Some(top_p);
data: Vec<Model>, self
} }
#[derive(Deserialize)] /// List available models
struct Model { pub async fn list_models(&self) -> Result<Vec<String>> {
id: String, let url = format!("{}/models", self.base_url);
}
let response = self
let result: ModelsResponse = response .client
.json() .get(&url)
.await .header("Authorization", format!("Bearer {}", self.api_key))
.context("Failed to parse OpenRouter response")?; .header("HTTP-Referer", "https://quicommit.dev")
.header("X-Title", "QuiCommit")
Ok(result.data.into_iter().map(|m| m.id).collect()) .send()
} .await
.context("Failed to list OpenRouter models")?;
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> { if !response.status().is_success() {
match self.list_models().await { let status = response.status();
Ok(_) => Ok(true), let text = response.text().await.unwrap_or_default();
Err(e) => { bail!("OpenRouter API error: {} - {}", status, text);
let err_str = e.to_string(); }
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false) #[derive(Deserialize)]
} else { struct ModelsResponse {
Err(e) data: Vec<Model>,
} }
}
} #[derive(Deserialize)]
} struct Model {
} id: String,
}
#[async_trait]
impl LlmProvider for OpenRouterClient { let result: ModelsResponse = response
async fn generate(&self, prompt: &str) -> Result<String> { .json()
let messages = vec![ .await
Message { .context("Failed to parse OpenRouter response")?;
role: "user".to_string(),
content: prompt.to_string(), Ok(result.data.into_iter().map(|m| m.id).collect())
}, }
];
/// Validate API key
self.chat_completion(messages).await pub async fn validate_key(&self) -> Result<bool> {
} match self.list_models().await {
Ok(_) => Ok(true),
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> { Err(e) => {
let mut messages = vec![]; let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
if !system.is_empty() { Ok(false)
messages.push(Message { } else {
role: "system".to_string(), Err(e)
content: system.to_string(), }
}); }
} }
}
messages.push(Message { }
role: "user".to_string(),
content: user.to_string(), #[async_trait]
}); impl LlmProvider for OpenRouterClient {
async fn generate(&self, prompt: &str) -> Result<String> {
self.chat_completion(messages).await let messages = vec![Message {
} role: "user".to_string(),
content: prompt.to_string(),
async fn is_available(&self) -> bool { }];
self.validate_key().await.unwrap_or(false)
} self.chat_completion(messages).await
}
fn name(&self) -> &str {
"openrouter" async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
} let mut messages = vec![];
}
if !system.is_empty() {
impl OpenRouterClient { messages.push(Message {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> { role: "system".to_string(),
let url = format!("{}/chat/completions", self.base_url); content: system.to_string(),
});
let request = ChatCompletionRequest { }
model: self.model.clone(),
messages, messages.push(Message {
max_tokens: Some(500), role: "user".to_string(),
temperature: Some(0.7), content: user.to_string(),
stream: false, });
};
self.chat_completion(messages).await
let response = self.client }
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key)) async fn is_available(&self) -> bool {
.header("Content-Type", "application/json") self.validate_key().await.unwrap_or(false)
.header("HTTP-Referer", "https://quicommit.dev") }
.header("X-Title", "QuiCommit")
.json(&request) fn name(&self) -> &str {
.send() "openrouter"
.await }
.context("Failed to send request to OpenRouter")?; }
let status = response.status(); impl OpenRouterClient {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
if !status.is_success() { let url = format!("{}/chat/completions", self.base_url);
let text = response.text().await.unwrap_or_default();
let request = ChatCompletionRequest {
// Try to parse error model: self.model.clone(),
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) { messages,
bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type); max_tokens: Some(self.max_tokens),
} temperature: Some(self.temperature),
stream: false,
bail!("OpenRouter API error: {} - {}", status, text); };
}
let response = self
let result: ChatCompletionResponse = response .client
.json() .post(&url)
.await .header("Authorization", format!("Bearer {}", self.api_key))
.context("Failed to parse OpenRouter response")?; .header("Content-Type", "application/json")
.header("HTTP-Referer", "https://quicommit.dev")
result.choices .header("X-Title", "QuiCommit")
.into_iter() .json(&request)
.next() .send()
.map(|c| c.message.content.trim().to_string()) .await
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) .context("Failed to send request to OpenRouter")?;
}
} let status = response.status();
/// Popular OpenRouter models if !status.is_success() {
pub const OPENROUTER_MODELS: &[&str] = &[ let text = response.text().await.unwrap_or_default();
"openai/gpt-3.5-turbo",
"openai/gpt-4", // Try to parse error
"openai/gpt-4-turbo", if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
"anthropic/claude-3-opus", bail!(
"anthropic/claude-3-sonnet", "OpenRouter API error: {} ({})",
"anthropic/claude-3-haiku", error.error.message,
"google/gemini-pro", error.error.error_type
"meta-llama/llama-2-70b-chat", );
"mistralai/mixtral-8x7b-instruct", }
"01-ai/yi-34b-chat",
]; bail!("OpenRouter API error: {} - {}", status, text);
}
/// Check if a model name is valid
pub fn is_valid_model(_model: &str) -> bool { let result: ChatCompletionResponse = response
// Since OpenRouter supports many models, we'll allow any model name .json()
// but provide some popular ones as suggestions .await
true .context("Failed to parse OpenRouter response")?;
}
result
#[cfg(test)] .choices
mod tests { .into_iter()
use super::*; .next()
.map(|c| c.message.content.trim().to_string())
#[test] .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
fn test_model_validation() { }
assert!(is_valid_model("openai/gpt-4")); }
assert!(is_valid_model("custom/model"));
} /// Popular OpenRouter models
} pub const OPENROUTER_MODELS: &[&str] = &[
"openai/gpt-3.5-turbo",
"openai/gpt-4",
"openai/gpt-4-turbo",
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
"google/gemini-pro",
"meta-llama/llama-2-70b-chat",
"mistralai/mixtral-8x7b-instruct",
"01-ai/yi-34b-chat",
];
/// Check if a model name is valid
pub fn is_valid_model(_model: &str) -> bool {
// Since OpenRouter supports many models, we'll allow any model name
// but provide some popular ones as suggestions
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation() {
assert!(is_valid_model("openai/gpt-4"));
assert!(is_valid_model("custom/model"));
}
}

151
src/llm/thinking.rs Normal file
View File

@@ -0,0 +1,151 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏
pub struct ThinkingStateManager {
is_thinking: AtomicBool,
on_start: Option<Box<dyn Fn() + Send + Sync>>,
on_end: Option<Box<dyn Fn() + Send + Sync>>,
}
impl ThinkingStateManager {
pub fn new() -> Self {
Self {
is_thinking: AtomicBool::new(false),
on_start: None,
on_end: None,
}
}
/// 设置思考开始回调
pub fn on_thinking_start<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
self.on_start = Some(Box::new(callback));
self
}
/// 设置思考结束回调
pub fn on_thinking_end<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
self.on_end = Some(Box::new(callback));
self
}
/// 开始思考状态
pub fn start_thinking(&self) {
if !self.is_thinking.load(Ordering::SeqCst) {
self.is_thinking.store(true, Ordering::SeqCst);
if let Some(ref cb) = self.on_start {
cb();
}
}
}
/// 结束思考状态
pub fn end_thinking(&self) {
if self.is_thinking.load(Ordering::SeqCst) {
self.is_thinking.store(false, Ordering::SeqCst);
if let Some(ref cb) = self.on_end {
cb();
}
}
}
/// 当前是否处于思考状态
pub fn is_thinking(&self) -> bool {
self.is_thinking.load(Ordering::SeqCst)
}
}
impl Default for ThinkingStateManager {
fn default() -> Self {
Self::new()
}
}
/// 线程安全的思考状态管理器引用
pub type SharedThinkingState = Arc<ThinkingStateManager>;
/// 创建带有默认控制台输出的思考状态管理器
/// 在思考开始时打印 "thinking...",在思考结束时清除该标识
pub fn create_console_thinking_state() -> SharedThinkingState {
Arc::new(
ThinkingStateManager::new()
.on_thinking_start(|| {
eprint!("\rthinking...");
})
.on_thinking_end(|| {
eprint!("\r \r");
}),
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
#[test]
fn test_thinking_state_transitions() {
let manager = ThinkingStateManager::new();
assert!(!manager.is_thinking());
manager.start_thinking();
assert!(manager.is_thinking());
manager.end_thinking();
assert!(!manager.is_thinking());
}
#[test]
fn test_thinking_idempotent_start() {
let manager = ThinkingStateManager::new();
manager.start_thinking();
manager.start_thinking(); // 重复调用不应触发回调两次
assert!(manager.is_thinking());
}
#[test]
fn test_thinking_idempotent_end() {
let manager = ThinkingStateManager::new();
manager.end_thinking(); // 未开始时结束不应触发问题
assert!(!manager.is_thinking());
}
#[test]
fn test_thinking_callbacks() {
let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
let manager = ThinkingStateManager::new().on_thinking_start(move || {
events_clone.lock().unwrap().push("start".to_string());
});
let events_clone2 = events.clone();
let manager = manager.on_thinking_end(move || {
events_clone2.lock().unwrap().push("end".to_string());
});
manager.start_thinking();
manager.end_thinking();
let recorded = events.lock().unwrap();
assert_eq!(recorded.len(), 2);
assert_eq!(recorded[0], "start");
assert_eq!(recorded[1], "end");
}
#[test]
fn test_create_console_thinking_state() {
let state = create_console_thinking_state();
assert!(!state.is_thinking());
state.start_thinking();
assert!(state.is_thinking());
state.end_thinking();
assert!(!state.is_thinking());
}
#[test]
fn test_default() {
let manager = ThinkingStateManager::default();
assert!(!manager.is_thinking());
}
}

View File

@@ -1,3 +1,5 @@
#![allow(dead_code)]
use anyhow::Result; use anyhow::Result;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use std::path::PathBuf; use std::path::PathBuf;
@@ -12,12 +14,12 @@ mod llm;
mod utils; mod utils;
use commands::{ use commands::{
changelog::ChangelogCommand, commit::CommitCommand, config::ConfigCommand, changelog::ChangelogCommand, commit::CommitCommand, config::ConfigCommand, init::InitCommand,
init::InitCommand, profile::ProfileCommand, tag::TagCommand, profile::ProfileCommand, tag::TagCommand,
}; };
/// QuiCommit - AI-powered Git assistant /// QuiCommit - AI-powered Git assistant
/// ///
/// A powerful tool that helps you generate conventional commits, tags, and changelogs /// A powerful tool that helps you generate conventional commits, tags, and changelogs
/// using AI (LLM APIs or local Ollama models). Manage multiple Git profiles for different /// using AI (LLM APIs or local Ollama models). Manage multiple Git profiles for different
/// work contexts seamlessly. /// work contexts seamlessly.
@@ -81,7 +83,7 @@ async fn main() -> Result<()> {
2 => "debug", 2 => "debug",
_ => "trace", _ => "trace",
}; };
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_env_filter(log_level) .with_env_filter(log_level)
.with_target(false) .with_target(false)

View File

@@ -1,9 +1,9 @@
use aes_gcm::{ use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce, Aes256Gcm, Nonce,
aead::{Aead, KeyInit},
}; };
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use rand::Rng; use rand::Rng;
use std::fs; use std::fs;
use std::path::Path; use std::path::Path;
@@ -18,63 +18,62 @@ pub fn encrypt(data: &[u8], password: &str) -> Result<String> {
rand::thread_rng().fill(&mut salt); rand::thread_rng().fill(&mut salt);
let mut nonce_bytes = [0u8; NONCE_LEN]; let mut nonce_bytes = [0u8; NONCE_LEN];
rand::thread_rng().fill(&mut nonce_bytes); rand::thread_rng().fill(&mut nonce_bytes);
let key = derive_key(password, &salt)?; let key = derive_key(password, &salt)?;
let cipher = Aes256Gcm::new_from_slice(&key) let cipher = Aes256Gcm::new_from_slice(&key).context("Failed to create cipher")?;
.context("Failed to create cipher")?;
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
let encrypted = cipher let encrypted = cipher
.encrypt(nonce, data) .encrypt(nonce, data)
.map_err(|e| anyhow::anyhow!("Encryption failed: {:?}", e))?; .map_err(|e| anyhow::anyhow!("Encryption failed: {:?}", e))?;
// Combine salt + nonce + encrypted data // Combine salt + nonce + encrypted data
let mut result = Vec::with_capacity(SALT_LEN + NONCE_LEN + encrypted.len()); let mut result = Vec::with_capacity(SALT_LEN + NONCE_LEN + encrypted.len());
result.extend_from_slice(&salt); result.extend_from_slice(&salt);
result.extend_from_slice(&nonce_bytes); result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&encrypted); result.extend_from_slice(&encrypted);
Ok(BASE64.encode(&result)) Ok(BASE64.encode(&result))
} }
/// Decrypt data with password /// Decrypt data with password
pub fn decrypt(encrypted_data: &str, password: &str) -> Result<Vec<u8>> { pub fn decrypt(encrypted_data: &str, password: &str) -> Result<Vec<u8>> {
let data = BASE64.decode(encrypted_data) let data = BASE64
.decode(encrypted_data)
.context("Invalid base64 encoding")?; .context("Invalid base64 encoding")?;
if data.len() < SALT_LEN + NONCE_LEN { if data.len() < SALT_LEN + NONCE_LEN {
anyhow::bail!("Invalid encrypted data format"); anyhow::bail!("Invalid encrypted data format");
} }
let salt = &data[..SALT_LEN]; let salt = &data[..SALT_LEN];
let nonce_bytes = &data[SALT_LEN..SALT_LEN + NONCE_LEN]; let nonce_bytes = &data[SALT_LEN..SALT_LEN + NONCE_LEN];
let encrypted = &data[SALT_LEN + NONCE_LEN..]; let encrypted = &data[SALT_LEN + NONCE_LEN..];
let key = derive_key(password, salt)?; let key = derive_key(password, salt)?;
let cipher = Aes256Gcm::new_from_slice(&key) let cipher = Aes256Gcm::new_from_slice(&key).context("Failed to create cipher")?;
.context("Failed to create cipher")?;
let nonce = Nonce::from_slice(nonce_bytes); let nonce = Nonce::from_slice(nonce_bytes);
let decrypted = cipher let decrypted = cipher
.decrypt(nonce, encrypted) .decrypt(nonce, encrypted)
.map_err(|e| anyhow::anyhow!("Decryption failed: {:?}", e))?; .map_err(|e| anyhow::anyhow!("Decryption failed: {:?}", e))?;
Ok(decrypted) Ok(decrypted)
} }
/// Derive key from password using simple method /// Derive key from password using simple method
fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; KEY_LEN]> { fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; KEY_LEN]> {
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(salt); hasher.update(salt);
hasher.update(password.as_bytes()); hasher.update(password.as_bytes());
hasher.update(b"quicommit_key_derivation_v1"); hasher.update(b"quicommit_key_derivation_v1");
let hash = hasher.finalize(); let hash = hasher.finalize();
let mut key = [0u8; KEY_LEN]; let mut key = [0u8; KEY_LEN];
key.copy_from_slice(&hash[..KEY_LEN]); key.copy_from_slice(&hash[..KEY_LEN]);
Ok(key) Ok(key)
} }
@@ -97,7 +96,7 @@ pub fn decrypt_from_file(path: &Path, password: &str) -> Result<Vec<u8>> {
pub fn generate_token(length: usize) -> String { pub fn generate_token(length: usize) -> String {
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
(0..length) (0..length)
.map(|_| { .map(|_| {
let idx = rng.gen_range(0..CHARSET.len()); let idx = rng.gen_range(0..CHARSET.len());
@@ -122,10 +121,10 @@ mod tests {
fn test_encrypt_decrypt() { fn test_encrypt_decrypt() {
let data = b"Hello, World!"; let data = b"Hello, World!";
let password = "my_secret_password"; let password = "my_secret_password";
let encrypted = encrypt(data, password).unwrap(); let encrypted = encrypt(data, password).unwrap();
let decrypted = decrypt(&encrypted, password).unwrap(); let decrypted = decrypt(&encrypted, password).unwrap();
assert_eq!(data.to_vec(), decrypted); assert_eq!(data.to_vec(), decrypted);
} }
@@ -133,7 +132,7 @@ mod tests {
fn test_wrong_password() { fn test_wrong_password() {
let data = b"Hello, World!"; let data = b"Hello, World!";
let encrypted = encrypt(data, "correct_password").unwrap(); let encrypted = encrypt(data, "correct_password").unwrap();
assert!(decrypt(&encrypted, "wrong_password").is_err()); assert!(decrypt(&encrypted, "wrong_password").is_err());
} }
} }

View File

@@ -9,15 +9,12 @@ pub fn edit_content(initial_content: &str) -> Result<String> {
/// Edit file in user's default editor /// Edit file in user's default editor
pub fn edit_file(path: &Path) -> Result<String> { pub fn edit_file(path: &Path) -> Result<String> {
let content = fs::read_to_string(path) let content = fs::read_to_string(path).unwrap_or_default();
.unwrap_or_default();
let edited = edit::edit(&content).context("Failed to open editor")?;
let edited = edit::edit(&content)
.context("Failed to open editor")?; fs::write(path, &edited).with_context(|| format!("Failed to write file: {:?}", path))?;
fs::write(path, &edited)
.with_context(|| format!("Failed to write file: {:?}", path))?;
Ok(edited) Ok(edited)
} }
@@ -27,11 +24,10 @@ pub fn edit_temp(initial_content: &str, extension: &str) -> Result<String> {
.suffix(&format!(".{}", extension)) .suffix(&format!(".{}", extension))
.tempfile() .tempfile()
.context("Failed to create temp file")?; .context("Failed to create temp file")?;
let path = temp_file.path(); let path = temp_file.path();
fs::write(path, initial_content) fs::write(path, initial_content).context("Failed to write temp file")?;
.context("Failed to write temp file")?;
edit_file(path) edit_file(path)
} }
@@ -41,10 +37,10 @@ pub fn get_editor() -> String {
.or_else(|_| std::env::var("VISUAL")) .or_else(|_| std::env::var("VISUAL"))
.unwrap_or_else(|_| { .unwrap_or_else(|_| {
if cfg!(target_os = "windows") { if cfg!(target_os = "windows") {
if let Ok(code) = which::which("code") { if let Ok(_code) = which::which("code") {
return "code --wait".to_string(); return "code --wait".to_string();
} }
if let Ok(notepad) = which::which("notepad") { if let Ok(_notepad) = which::which("notepad") {
return "notepad".to_string(); return "notepad".to_string();
} }
"notepad".to_string() "notepad".to_string()
@@ -65,7 +61,6 @@ pub fn get_editor() -> String {
/// Check if editor is available /// Check if editor is available
pub fn check_editor() -> Result<()> { pub fn check_editor() -> Result<()> {
let editor = get_editor(); let editor = get_editor();
which::which(&editor) which::which(&editor).with_context(|| format!("Editor '{}' not found in PATH", editor))?;
.with_context(|| format!("Editor '{}' not found in PATH", editor))?;
Ok(()) Ok(())
} }

View File

@@ -10,7 +10,7 @@ pub fn format_conventional_commit(
breaking: bool, breaking: bool,
) -> String { ) -> String {
let mut message = String::new(); let mut message = String::new();
message.push_str(commit_type); message.push_str(commit_type);
if let Some(s) = scope { if let Some(s) = scope {
message.push_str(&format!("({})", s)); message.push_str(&format!("({})", s));
@@ -19,15 +19,15 @@ pub fn format_conventional_commit(
message.push('!'); message.push('!');
} }
message.push_str(&format!(": {}", description)); message.push_str(&format!(": {}", description));
if let Some(b) = body { if let Some(b) = body {
message.push_str(&format!("\n\n{}", b)); message.push_str(&format!("\n\n{}", b));
} }
if let Some(f) = footer { if let Some(f) = footer {
message.push_str(&format!("\n\n{}", f)); message.push_str(&format!("\n\n{}", f));
} }
message message
} }
@@ -41,27 +41,27 @@ pub fn format_commitlint_commit(
references: Option<&[&str]>, references: Option<&[&str]>,
) -> String { ) -> String {
let mut message = String::new(); let mut message = String::new();
message.push_str(commit_type); message.push_str(commit_type);
if let Some(s) = scope { if let Some(s) = scope {
message.push_str(&format!("({})", s)); message.push_str(&format!("({})", s));
} }
message.push_str(&format!(": {}", subject)); message.push_str(&format!(": {}", subject));
if let Some(refs) = references { if let Some(refs) = references {
for reference in refs { for reference in refs {
message.push_str(&format!(" #{}", reference)); message.push_str(&format!(" #{}", reference));
} }
} }
if let Some(b) = body { if let Some(b) = body {
message.push_str(&format!("\n\n{}", b)); message.push_str(&format!("\n\n{}", b));
} }
if let Some(f) = footer { if let Some(f) = footer {
message.push_str(&format!("\n\n{}", f)); message.push_str(&format!("\n\n{}", f));
} }
message message
} }
@@ -73,7 +73,7 @@ pub fn wrap_text(text: &str, width: usize) -> String {
/// Clean commit message (remove comments, extra whitespace) /// Clean commit message (remove comments, extra whitespace)
pub fn clean_message(message: &str) -> String { pub fn clean_message(message: &str) -> String {
let comment_regex = Regex::new(r"^#.*$").unwrap(); let comment_regex = Regex::new(r"^#.*$").unwrap();
message message
.lines() .lines()
.filter(|line| !comment_regex.is_match(line.trim())) .filter(|line| !comment_regex.is_match(line.trim()))
@@ -97,7 +97,7 @@ mod tests {
Some("Closes #123"), Some("Closes #123"),
false, false,
); );
assert!(msg.contains("feat(auth): add login functionality")); assert!(msg.contains("feat(auth): add login functionality"));
assert!(msg.contains("This adds OAuth2 login support.")); assert!(msg.contains("This adds OAuth2 login support."));
assert!(msg.contains("Closes #123")); assert!(msg.contains("Closes #123"));
@@ -113,7 +113,7 @@ mod tests {
Some("BREAKING CHANGE: response format changed"), Some("BREAKING CHANGE: response format changed"),
true, true,
); );
assert!(msg.starts_with("feat!: change API response format")); assert!(msg.starts_with("feat!: change API response format"));
} }
} }

350
src/utils/keyring.rs Normal file
View File

@@ -0,0 +1,350 @@
use anyhow::{Context, Result, bail};
use std::env;
const SERVICE_NAME: &str = "quicommit";
const ENV_API_KEY: &str = "QUICOMMIT_API_KEY";
const PAT_SERVICE_PREFIX: &str = "quicommit/pat";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyringStatus {
Available,
Unavailable,
}
pub struct KeyringManager {
status: KeyringStatus,
}
impl KeyringManager {
pub fn new() -> Self {
let status = Self::check_keyring_availability();
Self { status }
}
pub fn check_keyring_availability() -> KeyringStatus {
#[cfg(target_os = "windows")]
{
KeyringStatus::Available
}
#[cfg(target_os = "macos")]
{
KeyringStatus::Available
}
#[cfg(target_os = "linux")]
{
Self::check_linux_keyring()
}
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
{
KeyringStatus::Unavailable
}
}
#[cfg(target_os = "linux")]
fn check_linux_keyring() -> KeyringStatus {
use std::path::Path;
let has_dbus = Path::new("/usr/bin/dbus-daemon").exists()
|| Path::new("/bin/dbus-daemon").exists()
|| env::var("DBUS_SESSION_BUS_ADDRESS").is_ok();
let has_keyring = Path::new("/usr/bin/gnome-keyring-daemon").exists()
|| Path::new("/usr/bin/gnome-keyring").exists()
|| Path::new("/usr/bin/kwalletd5").exists()
|| Path::new("/usr/bin/kwalletd6").exists()
|| env::var("SECRET_SERVICE").is_ok();
if has_dbus && has_keyring {
KeyringStatus::Available
} else {
KeyringStatus::Unavailable
}
}
pub fn status(&self) -> KeyringStatus {
self.status
}
pub fn is_available(&self) -> bool {
self.status == KeyringStatus::Available
}
pub fn store_api_key(&self, provider: &str, api_key: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
entry
.set_password(api_key)
.context("Failed to store API key")?;
Ok(())
}
pub fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
if let Ok(key) = env::var(ENV_API_KEY)
&& !key.is_empty()
{
return Ok(Some(key));
}
if !self.is_available() {
return Ok(None);
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
match entry.get_password() {
Ok(key) => Ok(Some(key)),
Err(keyring::Error::NoEntry) => Ok(None),
Err(e) => Err(e.into()),
}
}
pub fn delete_api_key(&self, provider: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let entry = keyring::Entry::new(SERVICE_NAME, provider)
.context("Failed to create keyring entry")?;
entry
.delete_credential()
.context("Failed to delete API key")?;
Ok(())
}
pub fn has_api_key(&self, provider: &str) -> bool {
self.get_api_key(provider).unwrap_or(None).is_some()
}
fn make_pat_service_name(profile_name: &str) -> String {
format!("{}/{}", PAT_SERVICE_PREFIX, profile_name)
}
pub fn store_pat(
&self,
profile_name: &str,
user_email: &str,
service: &str,
token: &str,
) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
entry
.set_password(token)
.context("Failed to store PAT in keyring")?;
eprintln!(
"[DEBUG] PAT stored in keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(())
}
pub fn get_pat(
&self,
profile_name: &str,
user_email: &str,
service: &str,
) -> Result<Option<String>> {
if !self.is_available() {
return Ok(None);
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
match entry.get_password() {
Ok(token) => {
eprintln!(
"[DEBUG] PAT retrieved from keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(Some(token))
}
Err(keyring::Error::NoEntry) => {
eprintln!(
"[DEBUG] PAT not found in keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(None)
}
Err(e) => Err(e.into()),
}
}
pub fn delete_pat(&self, profile_name: &str, user_email: &str, service: &str) -> Result<()> {
if !self.is_available() {
bail!("Keyring is not available on this system");
}
let keyring_service = Self::make_pat_service_name(profile_name);
let keyring_user = format!("{}:{}", user_email, service);
let entry = keyring::Entry::new(&keyring_service, &keyring_user)
.context("Failed to create keyring entry for PAT")?;
entry
.delete_credential()
.context("Failed to delete PAT from keyring")?;
eprintln!(
"[DEBUG] PAT deleted from keyring: service={}, user={}",
keyring_service, keyring_user
);
Ok(())
}
pub fn has_pat(&self, profile_name: &str, user_email: &str, service: &str) -> bool {
self.get_pat(profile_name, user_email, service)
.unwrap_or(None)
.is_some()
}
pub fn delete_all_pats_for_profile(
&self,
profile_name: &str,
user_email: &str,
services: &[String],
) -> Result<()> {
for service in services {
if let Err(e) = self.delete_pat(profile_name, user_email, service) {
eprintln!(
"[DEBUG] Failed to delete PAT for service '{}': {}",
service, e
);
}
}
Ok(())
}
pub fn get_status_message(&self) -> String {
match self.status {
KeyringStatus::Available => {
#[cfg(target_os = "windows")]
{
"Windows Credential Manager is available".to_string()
}
#[cfg(target_os = "macos")]
{
"macOS Keychain is available".to_string()
}
#[cfg(target_os = "linux")]
{
"Linux secret service is available".to_string()
}
#[cfg(not(any(target_os = "windows", target_os = "macos", target_os = "linux")))]
{
"Keyring is available".to_string()
}
}
KeyringStatus::Unavailable => {
"Keyring is not available. Set QUICOMMIT_API_KEY environment variable.".to_string()
}
}
}
}
impl Default for KeyringManager {
fn default() -> Self {
Self::new()
}
}
pub fn get_default_base_url(provider: &str) -> &'static str {
match provider {
"openai" => "https://api.openai.com/v1",
"anthropic" => "https://api.anthropic.com/v1",
"kimi" => "https://api.moonshot.cn/v1",
"deepseek" => "https://api.deepseek.com/v1",
"openrouter" => "https://openrouter.ai/api/v1",
"ollama" => "http://localhost:11434",
_ => "",
}
}
pub fn get_default_model(provider: &str) -> &'static str {
match provider {
"openai" => "gpt-4",
"anthropic" => "claude-3-sonnet-20240229",
"kimi" => "kimi-k2.6",
"deepseek" => "deepseek-v4-flash",
"openrouter" => "openai/gpt-3.5-turbo",
"ollama" => "llama2",
_ => "",
}
}
pub fn get_supported_providers() -> &'static [&'static str] {
&[
"ollama",
"openai",
"anthropic",
"kimi",
"deepseek",
"openrouter",
]
}
pub fn provider_needs_api_key(provider: &str) -> bool {
provider != "ollama"
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_default_base_url() {
assert_eq!(get_default_base_url("openai"), "https://api.openai.com/v1");
assert_eq!(
get_default_base_url("anthropic"),
"https://api.anthropic.com/v1"
);
assert_eq!(get_default_base_url("kimi"), "https://api.moonshot.cn/v1");
assert_eq!(
get_default_base_url("deepseek"),
"https://api.deepseek.com/v1"
);
assert_eq!(
get_default_base_url("openrouter"),
"https://openrouter.ai/api/v1"
);
assert_eq!(get_default_base_url("ollama"), "http://localhost:11434");
}
#[test]
fn test_get_default_model() {
assert_eq!(get_default_model("openai"), "gpt-4");
assert_eq!(get_default_model("anthropic"), "claude-3-sonnet-20240229");
assert_eq!(get_default_model("ollama"), "llama2");
}
#[test]
fn test_provider_needs_api_key() {
assert!(provider_needs_api_key("openai"));
assert!(provider_needs_api_key("anthropic"));
assert!(!provider_needs_api_key("ollama"));
}
}

View File

@@ -1,6 +1,7 @@
pub mod crypto; pub mod crypto;
pub mod editor; pub mod editor;
pub mod formatter; pub mod formatter;
pub mod keyring;
pub mod validators; pub mod validators;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@@ -31,10 +32,10 @@ pub fn print_info(msg: &str) {
pub fn confirm(prompt: &str) -> Result<bool> { pub fn confirm(prompt: &str) -> Result<bool> {
print!("{} [y/N] ", prompt); print!("{} [y/N] ", prompt);
io::stdout().flush()?; io::stdout().flush()?;
let mut input = String::new(); let mut input = String::new();
io::stdin().read_line(&mut input)?; io::stdin().read_line(&mut input)?;
Ok(input.trim().to_lowercase().starts_with('y')) Ok(input.trim().to_lowercase().starts_with('y'))
} }
@@ -42,17 +43,17 @@ pub fn confirm(prompt: &str) -> Result<bool> {
pub fn input(prompt: &str) -> Result<String> { pub fn input(prompt: &str) -> Result<String> {
print!("{}: ", prompt); print!("{}: ", prompt);
io::stdout().flush()?; io::stdout().flush()?;
let mut input = String::new(); let mut input = String::new();
io::stdin().read_line(&mut input)?; io::stdin().read_line(&mut input)?;
Ok(input.trim().to_string()) Ok(input.trim().to_string())
} }
/// Get password input (hidden) /// Get password input (hidden)
pub fn password_input(prompt: &str) -> Result<String> { pub fn password_input(prompt: &str) -> Result<String> {
use dialoguer::Password; use dialoguer::Password;
Password::new() Password::new()
.with_prompt(prompt) .with_prompt(prompt)
.interact() .interact()

View File

@@ -1,4 +1,4 @@
use anyhow::{bail, Result}; use anyhow::{Result, bail};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
@@ -67,7 +67,7 @@ lazy_static! {
/// Validate conventional commit message /// Validate conventional commit message
pub fn validate_conventional_commit(message: &str) -> Result<()> { pub fn validate_conventional_commit(message: &str) -> Result<()> {
let first_line = message.lines().next().unwrap_or(""); let first_line = message.lines().next().unwrap_or("");
if !CONVENTIONAL_COMMIT_REGEX.is_match(first_line) { if !CONVENTIONAL_COMMIT_REGEX.is_match(first_line) {
bail!( bail!(
"Invalid conventional commit format. Expected: <type>[optional scope]: <description>\n\ "Invalid conventional commit format. Expected: <type>[optional scope]: <description>\n\
@@ -75,32 +75,32 @@ pub fn validate_conventional_commit(message: &str) -> Result<()> {
CONVENTIONAL_TYPES.join(", ") CONVENTIONAL_TYPES.join(", ")
); );
} }
if first_line.len() > 100 { if first_line.len() > 100 {
bail!("Commit subject too long (max 100 characters)"); bail!("Commit subject too long (max 100 characters)");
} }
Ok(()) Ok(())
} }
/// Validate @commitlint commit message /// Validate @commitlint commit message
pub fn validate_commitlint_commit(message: &str) -> Result<()> { pub fn validate_commitlint_commit(message: &str) -> Result<()> {
let first_line = message.lines().next().unwrap_or(""); let first_line = message.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.splitn(2, ':').collect(); let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 { if parts.len() != 2 {
bail!("Invalid commit format. Expected: <type>[optional scope]: <subject>"); bail!("Invalid commit format. Expected: <type>[optional scope]: <subject>");
} }
let type_part = parts[0]; let type_part = parts[0];
let subject = parts[1].trim(); let subject = parts[1].trim();
let commit_type = type_part let commit_type = type_part
.split('(') .split('(')
.next() .next()
.unwrap_or("") .unwrap_or("")
.trim_end_matches('!'); .trim_end_matches('!');
if !COMMITLINT_TYPES.contains(&commit_type) { if !COMMITLINT_TYPES.contains(&commit_type) {
bail!( bail!(
"Invalid commit type: '{}'. Valid types: {}", "Invalid commit type: '{}'. Valid types: {}",
@@ -108,27 +108,32 @@ pub fn validate_commitlint_commit(message: &str) -> Result<()> {
COMMITLINT_TYPES.join(", ") COMMITLINT_TYPES.join(", ")
); );
} }
if subject.is_empty() { if subject.is_empty() {
bail!("Commit subject cannot be empty"); bail!("Commit subject cannot be empty");
} }
if subject.len() < 4 { if subject.len() < 4 {
bail!("Commit subject too short (min 4 characters)"); bail!("Commit subject too short (min 4 characters)");
} }
if subject.len() > 100 { if subject.len() > 100 {
bail!("Commit subject too long (max 100 characters)"); bail!("Commit subject too long (max 100 characters)");
} }
if subject.chars().next().map(|c| c.is_uppercase()).unwrap_or(false) { if subject
.chars()
.next()
.map(|c| c.is_uppercase())
.unwrap_or(false)
{
bail!("Commit subject should not start with uppercase letter"); bail!("Commit subject should not start with uppercase letter");
} }
if subject.ends_with('.') { if subject.ends_with('.') {
bail!("Commit subject should not end with a period"); bail!("Commit subject should not end with a period");
} }
Ok(()) Ok(())
} }
@@ -137,25 +142,25 @@ pub fn validate_scope(scope: &str) -> Result<()> {
if scope.is_empty() { if scope.is_empty() {
bail!("Scope cannot be empty"); bail!("Scope cannot be empty");
} }
if !SCOPE_REGEX.is_match(scope) { if !SCOPE_REGEX.is_match(scope) {
bail!("Invalid scope format. Use lowercase letters, numbers, and hyphens only"); bail!("Invalid scope format. Use lowercase letters, numbers, and hyphens only");
} }
Ok(()) Ok(())
} }
/// Validate semantic version tag /// Validate semantic version tag
pub fn validate_semver(version: &str) -> Result<()> { pub fn validate_semver(version: &str) -> Result<()> {
let version = version.trim_start_matches('v'); let version = version.trim_start_matches('v');
if !SEMVER_REGEX.is_match(version) { if !SEMVER_REGEX.is_match(version) {
bail!( bail!(
"Invalid semantic version format. Expected: MAJOR.MINOR.PATCH[-prerelease][+build]\n\ "Invalid semantic version format. Expected: MAJOR.MINOR.PATCH[-prerelease][+build]\n\
Examples: 1.0.0, 1.2.3-beta, v2.0.0+build123" Examples: 1.0.0, 1.2.3-beta, v2.0.0+build123"
); );
} }
Ok(()) Ok(())
} }
@@ -164,7 +169,7 @@ pub fn validate_email(email: &str) -> Result<()> {
if !EMAIL_REGEX.is_match(email) { if !EMAIL_REGEX.is_match(email) {
bail!("Invalid email address format"); bail!("Invalid email address format");
} }
Ok(()) Ok(())
} }
@@ -173,7 +178,7 @@ pub fn validate_gpg_key_id(key_id: &str) -> Result<()> {
if !GPG_KEY_ID_REGEX.is_match(key_id) { if !GPG_KEY_ID_REGEX.is_match(key_id) {
bail!("Invalid GPG key ID format. Expected 16-40 hexadecimal characters"); bail!("Invalid GPG key ID format. Expected 16-40 hexadecimal characters");
} }
Ok(()) Ok(())
} }
@@ -182,15 +187,18 @@ pub fn validate_profile_name(name: &str) -> Result<()> {
if name.is_empty() { if name.is_empty() {
bail!("Profile name cannot be empty"); bail!("Profile name cannot be empty");
} }
if name.len() > 50 { if name.len() > 50 {
bail!("Profile name too long (max 50 characters)"); bail!("Profile name too long (max 50 characters)");
} }
if !name.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_') { if !name
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
{
bail!("Profile name can only contain letters, numbers, hyphens, and underscores"); bail!("Profile name can only contain letters, numbers, hyphens, and underscores");
} }
Ok(()) Ok(())
} }
@@ -201,7 +209,7 @@ pub fn is_valid_commit_type(commit_type: &str, use_commitlint: bool) -> bool {
} else { } else {
CONVENTIONAL_TYPES CONVENTIONAL_TYPES
}; };
types.contains(&commit_type) types.contains(&commit_type)
} }

View File

@@ -0,0 +1,434 @@
use assert_cmd::cargo::cargo_bin_cmd;
use predicates::prelude::*;
use std::fs;
use std::path::PathBuf;
use tempfile::TempDir;
fn init_quicommit(config_path: &PathBuf) {
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
}
mod config_export {
use super::*;
#[test]
fn test_export_to_stdout() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("version"))
.stdout(predicate::str::contains("[llm]"));
}
#[test]
fn test_export_to_file() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("exported.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration exported"));
assert!(export_path.exists(), "Export file should be created");
let content = fs::read_to_string(&export_path).unwrap();
assert!(content.contains("version"), "Export should contain version");
assert!(
content.contains("[llm]"),
"Export should contain LLM config"
);
}
#[test]
fn test_export_encrypted() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"test_password_123",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("encrypted and exported"));
assert!(export_path.exists(), "Export file should be created");
let content = fs::read_to_string(&export_path).unwrap();
assert!(
content.starts_with("ENCRYPTED:"),
"Encrypted file should start with ENCRYPTED:"
);
assert!(
!content.contains("[llm]"),
"Encrypted content should not be readable"
);
}
}
mod config_import {
use super::*;
#[test]
fn test_import_plain_config() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let import_path = temp_dir.path().join("import.toml");
let plain_config = r#"
version = "1"
[llm]
provider = "openai"
model = "gpt-4"
max_tokens = 1000
temperature = 0.7
timeout = 60
api_key_storage = "keyring"
[commit]
format = "conventional"
auto_generate = true
allow_empty = false
gpg_sign = false
max_subject_length = 100
require_scope = false
require_body = false
body_required_types = ["feat", "fix"]
[tag]
version_prefix = "v"
auto_generate = true
gpg_sign = false
include_changelog = true
[changelog]
path = "CHANGELOG.md"
auto_generate = true
format = "keep-a-changelog"
include_hashes = false
include_authors = false
group_by_type = true
[theme]
colors = true
icons = true
date_format = "%Y-%m-%d"
[language]
output_language = "en"
keep_types_english = true
keep_changelog_types_english = true
"#;
fs::write(&import_path, plain_config).unwrap();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path.to_str().unwrap(),
"--file",
import_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration imported"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("openai"));
}
#[test]
fn test_import_encrypted_config() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.provider",
"anthropic",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"secure_password",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
"secure_password",
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("Configuration imported"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("anthropic"));
}
#[test]
fn test_import_encrypted_wrong_password() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let export_path = temp_dir.path().join("encrypted.toml");
init_quicommit(&config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"correct_password",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
"wrong_password",
]);
cmd.assert()
.failure()
.stderr(predicate::str::contains("Failed to decrypt"));
}
}
mod config_export_import_roundtrip {
use super::*;
#[test]
fn test_roundtrip_plain() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("export.toml");
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.model",
"gpt-4-turbo",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
"",
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.model",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("gpt-4-turbo"));
}
#[test]
fn test_roundtrip_encrypted() {
let temp_dir = TempDir::new().unwrap();
let config_path1 = temp_dir.path().join("config1.toml");
let config_path2 = temp_dir.path().join("config2.toml");
let export_path = temp_dir.path().join("encrypted.toml");
let password = "my_secure_password_123";
init_quicommit(&config_path1);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.provider",
"deepseek",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"set",
"llm.model",
"deepseek-chat",
"--config",
config_path1.to_str().unwrap(),
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"export",
"--config",
config_path1.to_str().unwrap(),
"--output",
export_path.to_str().unwrap(),
"--password",
password,
]);
cmd.assert().success();
let exported_content = fs::read_to_string(&export_path).unwrap();
assert!(exported_content.starts_with("ENCRYPTED:"));
assert!(!exported_content.contains("deepseek"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"import",
"--config",
config_path2.to_str().unwrap(),
"--file",
export_path.to_str().unwrap(),
"--password",
password,
]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.provider",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("deepseek"));
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"get",
"llm.model",
"--config",
config_path2.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("deepseek-chat"));
}
}

View File

@@ -1,4 +1,4 @@
use assert_cmd::Command; use assert_cmd::cargo::cargo_bin_cmd;
use predicates::prelude::*; use predicates::prelude::*;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
@@ -47,22 +47,43 @@ fn create_commit(dir: &PathBuf, message: &str) {
.expect("Failed to create commit"); .expect("Failed to create commit");
} }
fn setup_git_repo(dir: &PathBuf) {
create_git_repo(dir);
configure_git_user(dir);
}
fn setup_test_repo_with_file(dir: &PathBuf, file_name: &str, file_content: &str) {
setup_git_repo(dir);
create_test_file(dir, file_name, file_content);
stage_file(dir, file_name);
}
fn init_quicommit(dir: &PathBuf, config_path: &PathBuf) {
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(dir);
cmd.assert().success();
}
mod cli_basic { mod cli_basic {
use super::*; use super::*;
#[test] #[test]
fn test_help() { fn test_help() {
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.arg("--help"); cmd.arg("--help");
cmd.assert() cmd.assert()
.success() .success()
.stdout(predicate::str::contains("QuiCommit")) .stdout(predicate::str::contains("QuiCommit"))
.stdout(predicate::str::contains("AI-powered Git assistant")); .stdout(predicate::str::contains("AI-powered Git assistant"))
.stdout(predicate::str::contains("Usage:"))
.stdout(predicate::str::contains("Commands:"))
.stdout(predicate::str::contains("Options:"));
} }
#[test] #[test]
fn test_version() { fn test_version() {
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.arg("--version"); cmd.arg("--version");
cmd.assert() cmd.assert()
.success() .success()
@@ -71,7 +92,7 @@ mod cli_basic {
#[test] #[test]
fn test_no_args_shows_help() { fn test_no_args_shows_help() {
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.assert() cmd.assert()
.failure() .failure()
.stderr(predicate::str::contains("Usage:")); .stderr(predicate::str::contains("Usage:"));
@@ -85,9 +106,15 @@ mod cli_basic {
create_git_repo(&repo_path); create_git_repo(&repo_path);
configure_git_user(&repo_path); configure_git_user(&repo_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["-vv", "init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "-vv",
"init",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success(); cmd.assert().success();
} }
@@ -101,7 +128,7 @@ mod init_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -114,7 +141,7 @@ mod init_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
@@ -131,7 +158,7 @@ mod init_command {
let config_path = repo_path.join("test_config.toml"); let config_path = repo_path.join("test_config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path); .current_dir(&repo_path);
@@ -143,12 +170,18 @@ mod init_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--reset", "--config", config_path.to_str().unwrap()]); cmd.args(&[
"init",
"--yes",
"--reset",
"--config",
config_path.to_str().unwrap(),
]);
cmd.assert() cmd.assert()
.success() .success()
.stdout(predicate::str::contains("initialized successfully")); .stdout(predicate::str::contains("initialized successfully"));
@@ -163,7 +196,7 @@ mod profile_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]); cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -176,11 +209,11 @@ mod profile_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]); cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -197,11 +230,11 @@ mod config_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]); cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -214,11 +247,11 @@ mod config_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]); cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -235,13 +268,19 @@ mod commit_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(temp_dir.path()); "commit",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(temp_dir.path());
cmd.assert() cmd.assert()
.failure() .failure()
@@ -252,19 +291,23 @@ mod commit_command {
fn test_commit_no_changes() { fn test_commit_no_changes() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_git_repo(&repo_path);
configure_git_user(&repo_path);
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "commit",
cmd.assert().success(); "--manual",
"-m",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "test: empty",
cmd.args(&["commit", "--manual", "-m", "test: empty", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--dry-run",
.current_dir(&repo_path); "--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
@@ -275,22 +318,23 @@ mod commit_command {
fn test_commit_with_staged_changes() { fn test_commit_with_staged_changes() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!");
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "Hello, World!");
stage_file(&repo_path, "test.txt");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "commit",
cmd.assert().success(); "--manual",
"-m",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "test: add test file",
cmd.args(&["commit", "--manual", "-m", "test: add test file", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--dry-run",
.current_dir(&repo_path); "--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
@@ -301,27 +345,52 @@ mod commit_command {
fn test_commit_date_mode() { fn test_commit_date_mode() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_test_repo_with_file(&repo_path, "daily.txt", "Daily update");
configure_git_user(&repo_path);
create_test_file(&repo_path, "daily.txt", "Daily update");
stage_file(&repo_path, "daily.txt");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "commit",
cmd.assert().success(); "--date",
"--dry-run",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "--yes",
cmd.args(&["commit", "--date", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--config",
.current_dir(&repo_path); config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
.stdout(predicate::str::contains("Dry run")); .stdout(predicate::str::contains("Dry run"));
} }
#[test]
fn test_commit_with_think_flag() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--think",
"--manual",
"-m",
"test: think flag",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
}
} }
mod tag_command { mod tag_command {
@@ -332,13 +401,19 @@ mod tag_command {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["tag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(temp_dir.path()); "tag",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(temp_dir.path());
cmd.assert() cmd.assert()
.failure() .failure()
@@ -349,28 +424,60 @@ mod tag_command {
fn test_tag_list_empty() { fn test_tag_list_empty() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content"); create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt"); stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: initial commit"); create_commit(&repo_path, "feat: initial commit");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "tag",
cmd.assert().success(); "--name",
"v0.1.0",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "--dry-run",
cmd.args(&["tag", "--name", "v0.1.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--yes",
.current_dir(&repo_path); "--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
.stdout(predicate::str::contains("v0.1.0")); .stdout(predicate::str::contains("v0.1.0"));
} }
#[test]
fn test_tag_with_think_flag() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_git_repo(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: initial commit");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"tag",
"--think",
"--name",
"v0.2.0",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success();
}
} }
mod changelog_command { mod changelog_command {
@@ -380,20 +487,23 @@ mod changelog_command {
fn test_changelog_init() { fn test_changelog_init() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_git_repo(&repo_path);
configure_git_user(&repo_path);
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
let changelog_path = repo_path.join("CHANGELOG.md"); let changelog_path = repo_path.join("CHANGELOG.md");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); init_quicommit(&repo_path, &config_path);
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["changelog", "--init", "--output", changelog_path.to_str().unwrap(), "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "changelog",
"--init",
"--output",
changelog_path.to_str().unwrap(),
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().success(); cmd.assert().success();
@@ -404,26 +514,26 @@ mod changelog_command {
fn test_changelog_dry_run() { fn test_changelog_dry_run() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_git_repo(&repo_path);
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content"); create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt"); stage_file(&repo_path, "test.txt");
create_commit(&repo_path, "feat: add feature"); create_commit(&repo_path, "feat: add feature");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"changelog",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap();
cmd.args(&["changelog", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
.current_dir(&repo_path);
cmd.assert()
.success();
} }
} }
@@ -435,7 +545,7 @@ mod cross_platform {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("subdir").join("config.toml"); let config_path = temp_dir.path().join("subdir").join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
@@ -449,7 +559,7 @@ mod cross_platform {
fs::create_dir_all(&space_dir).unwrap(); fs::create_dir_all(&space_dir).unwrap();
let config_path = space_dir.join("config.toml"); let config_path = space_dir.join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
@@ -463,7 +573,7 @@ mod cross_platform {
fs::create_dir_all(&unicode_dir).unwrap(); fs::create_dir_all(&unicode_dir).unwrap();
let config_path = unicode_dir.join("config.toml"); let config_path = unicode_dir.join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
@@ -532,22 +642,23 @@ mod validators {
fn test_commit_message_validation() { fn test_commit_message_validation() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_test_repo_with_file(&repo_path, "test.txt", "content");
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "commit",
cmd.assert().success(); "--manual",
"-m",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "invalid commit message without type",
cmd.args(&["commit", "--manual", "-m", "invalid commit message without type", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--dry-run",
.current_dir(&repo_path); "--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.failure() .failure()
@@ -558,22 +669,23 @@ mod validators {
fn test_valid_conventional_commit() { fn test_valid_conventional_commit() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_test_repo_with_file(&repo_path, "test.txt", "content");
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "commit",
cmd.assert().success(); "--manual",
"-m",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "feat: add new feature",
cmd.args(&["commit", "--manual", "-m", "feat: add new feature", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--dry-run",
.current_dir(&repo_path); "--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
@@ -588,22 +700,23 @@ mod subcommands {
fn test_commit_alias() { fn test_commit_alias() {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf(); let repo_path = temp_dir.path().to_path_buf();
create_git_repo(&repo_path); setup_test_repo_with_file(&repo_path, "test.txt", "content");
configure_git_user(&repo_path);
create_test_file(&repo_path, "test.txt", "content");
stage_file(&repo_path, "test.txt");
let config_path = repo_path.join("config.toml"); let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) cmd.args(&[
.current_dir(&repo_path); "c",
cmd.assert().success(); "--manual",
"-m",
let mut cmd = Command::cargo_bin("quicommit").unwrap(); "fix: test",
cmd.args(&["c", "--manual", "-m", "fix: test", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) "--dry-run",
.current_dir(&repo_path); "--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert() cmd.assert()
.success() .success()
@@ -615,7 +728,7 @@ mod subcommands {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -628,11 +741,11 @@ mod subcommands {
let temp_dir = TempDir::new().unwrap(); let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml"); let config_path = temp_dir.path().join("config.toml");
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success(); cmd.assert().success();
let mut cmd = Command::cargo_bin("quicommit").unwrap(); let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]); cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]);
cmd.assert() cmd.assert()
@@ -640,3 +753,79 @@ mod subcommands {
.stdout(predicate::str::contains("default")); .stdout(predicate::str::contains("default"));
} }
} }
mod edge_cases {
use super::*;
#[test]
fn test_config_file_not_found() {
let temp_dir = TempDir::new().unwrap();
let non_existent_config = temp_dir.path().join("non_existent_config.toml");
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"config",
"show",
"--config",
non_existent_config.to_str().unwrap(),
]);
cmd.assert()
.success()
.stdout(predicate::str::contains("QuiCommit Configuration"))
.stdout(predicate::str::contains("Default profile: (none)"))
.stdout(predicate::str::contains("Profiles: 0"));
}
#[test]
fn test_invalid_git_repo() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
let config_path = repo_path.join("config.toml");
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
cmd.assert().success();
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert()
.failure()
.stderr(predicate::str::contains("git").or(predicate::str::contains("repository")));
}
#[test]
fn test_empty_commit_message() {
let temp_dir = TempDir::new().unwrap();
let repo_path = temp_dir.path().to_path_buf();
setup_test_repo_with_file(&repo_path, "test.txt", "content");
let config_path = repo_path.join("config.toml");
init_quicommit(&repo_path, &config_path);
let mut cmd = cargo_bin_cmd!("quicommit");
cmd.args(&[
"commit",
"--manual",
"-m",
"",
"--dry-run",
"--yes",
"--config",
config_path.to_str().unwrap(),
])
.current_dir(&repo_path);
cmd.assert().failure().stderr(predicate::str::contains(
"Invalid conventional commit format",
));
}
}