diff --git a/.gitignore b/.gitignore index dd0bb53..a5de515 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ test_output/ # Config (for development) config.toml +.claude/ +CLAUDE.md \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index e960052..00087dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "quicommit" -version = "0.2.1" +version = "0.2.3" edition = "2024" authors = ["Sidney Zhang "] description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)" @@ -33,7 +33,7 @@ git2 = "0.20.3" which = "6.0" # HTTP client for LLM APIs -reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } +reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false } tokio = { version = "1.35", features = ["full"] } # Error handling @@ -57,6 +57,7 @@ sha2 = "0.10" hex = "0.4" textwrap = "0.16" async-trait = "0.1" +futures-util = "0.3" serde_json = "1.0" atty = "0.2" diff --git a/src/commands/changelog.rs b/src/commands/changelog.rs index 9369950..f391d7d 100644 --- a/src/commands/changelog.rs +++ b/src/commands/changelog.rs @@ -55,6 +55,10 @@ pub struct ChangelogCommand { #[arg(long)] dry_run: bool, + /// Enable thinking mode for this changelog (override config) + #[arg(long)] + think: bool, + /// Skip interactive prompts #[arg(short = 'y', long)] yes: bool, @@ -74,8 +78,7 @@ impl ChangelogCommand { // Initialize changelog if requested if self.init { - let path = self.output.as_ref() - .map(|p| p.clone()) + let path = self.output.clone() .unwrap_or_else(|| PathBuf::from(&config.changelog.path)); init_changelog(&path)?; @@ -84,8 +87,7 @@ impl ChangelogCommand { } // Determine output path - let output_path = self.output.as_ref() - .map(|p| p.clone()) + let output_path = self.output.clone() .unwrap_or_else(|| PathBuf::from(&config.changelog.path)); // Determine format @@ -148,7 +150,7 @@ impl ChangelogCommand { println!("{}", "─".repeat(60)); let confirm = Confirm::new() - .with_prompt(&messages.write_to_file(&format!("{:?}", output_path))) + .with_prompt(messages.write_to_file(&format!("{:?}", output_path))) .default(true) .interact()?; @@ -208,7 +210,7 @@ impl ChangelogCommand { println!("{}", messages.ai_generating_changelog()); - let generator = ContentGenerator::new(&manager).await?; + let generator = ContentGenerator::new_with_think(&manager, self.think).await?; generator.generate_changelog_entry(version, commits, language).await } @@ -237,7 +239,8 @@ impl ChangelogCommand { } fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String { - let translated = changelog + + changelog .lines() .map(|line| { if line.starts_with("## ") || line.starts_with("### ") { @@ -253,7 +256,6 @@ impl ChangelogCommand { } }) .collect::>() - .join("\n"); - translated + .join("\n") } } diff --git a/src/commands/commit.rs b/src/commands/commit.rs index 5685669..de69f6e 100644 --- a/src/commands/commit.rs +++ b/src/commands/commit.rs @@ -71,6 +71,10 @@ pub struct CommitCommand { #[arg(long)] no_verify: bool, + /// Enable thinking mode for this commit (override config) + #[arg(short = 't', long)] + think: bool, + /// Skip interactive prompts #[arg(short = 'y', long)] yes: bool, @@ -198,7 +202,7 @@ impl CommitCommand { if let Some(commit_oid) = result { println!("{} {}", messages.commit_created().green().bold(), commit_oid.to_string()[..8].to_string().cyan()); } else { - println!("{} {}", messages.commit_amended().green().bold(), "successfully"); + println!("{} successfully", messages.commit_amended().green().bold()); } // Push after commit if requested or ask user @@ -258,7 +262,7 @@ impl CommitCommand { async fn generate_commit(&self, repo: &GitRepo, format: CommitFormat, messages: &Messages) -> Result { let manager = ConfigManager::new()?; - let generator = ContentGenerator::new(&manager).await + let generator = ContentGenerator::new_with_think(&manager, self.think).await .context("Failed to initialize LLM. Use --manual for manual commit.")?; println!("{}", messages.ai_analyzing()); diff --git a/src/commands/config.rs b/src/commands/config.rs index 9ece0ce..e95c4c0 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -393,8 +393,15 @@ impl ConfigCommand { let timeout: u64 = value.parse()?; manager.config_mut().llm.timeout = timeout; } + "llm.thinking_enabled" => { + manager.config_mut().llm.thinking_enabled = value == "true"; + } + "llm.thinking_budget_tokens" => { + let budget: u32 = value.parse()?; + manager.config_mut().llm.thinking_budget_tokens = Some(budget); + } "llm.api_key_storage" => { - let valid_values = vec!["keyring", "config", "environment"]; + let valid_values = ["keyring", "config", "environment"]; if !valid_values.contains(&value) { bail!("Invalid value: {}. Use: {}", value, valid_values.join(", ")); } @@ -433,6 +440,8 @@ impl ConfigCommand { "llm.max_tokens" => config.llm.max_tokens.to_string(), "llm.temperature" => config.llm.temperature.to_string(), "llm.timeout" => config.llm.timeout.to_string(), + "llm.thinking_enabled" => config.llm.thinking_enabled.to_string(), + "llm.thinking_budget_tokens" => config.llm.thinking_budget_tokens.map(|v| v.to_string()).unwrap_or_else(|| "none".to_string()), "llm.api_key_storage" => config.llm.api_key_storage.clone(), "commit.format" => config.commit.format.to_string(), "commit.auto_generate" => config.commit.auto_generate.to_string(), @@ -681,15 +690,13 @@ impl ConfigCommand { l.to_string() } else { println!("{}", "Select Output Language:".bold()); - let languages = vec![ - ("en", "English"), + let languages = [("en", "English"), ("zh", "中文"), ("ja", "日本語"), ("ko", "한국어"), ("es", "Español"), ("fr", "Français"), - ("de", "Deutsch"), - ]; + ("de", "Deutsch")]; let lang_names: Vec<&str> = languages.iter().map(|(_, n)| *n).collect(); let idx = Select::new() @@ -879,8 +886,8 @@ impl ConfigCommand { manager.import(&config_content)?; manager.save()?; - if let (Some(pats), Some(pwd)) = (encrypted_pats, pwd) { - if !pats.is_empty() { + if let (Some(pats), Some(pwd)) = (encrypted_pats, pwd) + && !pats.is_empty() { println!(); println!("{}", "Importing Personal Access Tokens...".bold()); @@ -951,7 +958,6 @@ impl ConfigCommand { println!("{} {} token(s) failed to import", "⚠".yellow(), failed_count); } } - } println!("{} Configuration imported from {}", "✓".green(), file); Ok(()) @@ -969,31 +975,35 @@ impl ConfigCommand { println!("Ollama models (local):"); println!(" llama2, llama2-uncensored, llama2:13b"); println!(" codellama, codellama:34b"); - println!(" mistral, mixtral"); - println!(" phi, gemma"); + println!(" mistral, mixtral, phi, gemma"); println!("\nRun 'ollama list' to see installed models"); } "openai" => { println!("OpenAI models:"); - println!(" gpt-4, gpt-4-turbo, gpt-4o"); - println!(" gpt-3.5-turbo, gpt-3.5-turbo-16k"); + println!(" o-series (reasoning): o4-mini, o3, o3-mini, o1, o1-mini, o1-pro"); + println!(" GPT-4: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4o, gpt-4o-mini"); + println!(" Legacy: gpt-4-turbo, gpt-4, gpt-3.5-turbo"); + println!("\nUse --think/-t with o-series models for reasoning mode."); } "anthropic" => { println!("Anthropic Claude models:"); - println!(" claude-3-opus-20240229"); - println!(" claude-3-sonnet-20240229"); - println!(" claude-3-haiku-20240307"); + println!(" Claude 4 (thinking): claude-opus-4-7, claude-sonnet-4-6, claude-haiku-4-5"); + println!(" Claude 3.5: claude-3-opus-20240229, claude-3-sonnet-20240229, claude-3-haiku-20240307"); + println!(" Legacy: claude-2.1, claude-2.0, claude-instant-1.2"); + println!("\nUse --think/-t with Claude 4 models for extended thinking."); } "kimi" => { println!("Kimi (Moonshot AI) models:"); - println!(" moonshot-v1-8k"); - println!(" moonshot-v1-32k"); - println!(" moonshot-v1-128k"); + println!(" K2 (thinking): kimi-k2.6, kimi-k2.5, kimi-k2-thinking, kimi-k2-thinking-turbo"); + println!(" K2 instruct: kimi-k2-instruct, kimi-k2-instruct-0905"); + println!(" Legacy: moonshot-v1-8k, moonshot-v1-32k, moonshot-v1-128k"); + println!("\nUse --think/-t with K2 models for thinking mode."); } "deepseek" => { println!("DeepSeek models:"); - println!(" deepseek-chat"); - println!(" deepseek-reasoner"); + println!(" V4: deepseek-v4-flash, deepseek-v4-pro"); + println!(" Legacy: deepseek-chat, deepseek-reasoner (deprecated 2026-07-24)"); + println!("\nUse --think/-t for reasoning mode."); } "openrouter" => { println!("OpenRouter models (examples):"); diff --git a/src/commands/init.rs b/src/commands/init.rs index 208340c..8983d0f 100644 --- a/src/commands/init.rs +++ b/src/commands/init.rs @@ -102,15 +102,13 @@ impl InitCommand { println!("\n{}", messages.setup_profile().bold()); println!("\n{}", messages.select_output_language().bold()); - let languages = vec![ - Language::English, + let languages = [Language::English, Language::Chinese, Language::Japanese, Language::Korean, Language::Spanish, Language::French, - Language::German, - ]; + Language::German]; let language_names: Vec = languages.iter().map(|l| l.display_name().to_string()).collect(); let language_idx = Select::new() .items(&language_names) @@ -297,12 +295,11 @@ impl InitCommand { manager.set_llm_model(model); manager.set_llm_base_url(base_url); - if let Some(key) = api_key { - if provider_needs_api_key(&provider) { + if let Some(key) = api_key + && provider_needs_api_key(&provider) { manager.set_api_key(&key)?; println!("\n{} {}", "✓".green(), "API key stored securely in system keyring.".green()); } - } Ok(()) } diff --git a/src/commands/profile.rs b/src/commands/profile.rs index cd250b0..01676fa 100644 --- a/src/commands/profile.rs +++ b/src/commands/profile.rs @@ -260,7 +260,7 @@ impl ProfileCommand { } let confirm = Confirm::new() - .with_prompt(&format!("Are you sure you want to remove profile '{}'?", name)) + .with_prompt(format!("Are you sure you want to remove profile '{}'?", name)) .default(false) .interact()?; @@ -564,7 +564,7 @@ impl ProfileCommand { if manager.has_profile(&profile_name) { let overwrite = Confirm::new() - .with_prompt(&format!("Profile '{}' already exists. Overwrite?", profile_name)) + .with_prompt(format!("Profile '{}' already exists. Overwrite?", profile_name)) .default(false) .interact()?; if !overwrite { @@ -575,7 +575,7 @@ impl ProfileCommand { let description: String = Input::new() .with_prompt("Description (optional)") - .default(format!("Imported from existing git config")) + .default("Imported from existing git config".to_string()) .allow_empty(true) .interact_text()?; @@ -849,7 +849,7 @@ impl ProfileCommand { println!("{}", "─".repeat(40)); let token_value: String = Input::new() - .with_prompt(&format!("Token for {}", service)) + .with_prompt(format!("Token for {}", service)) .interact_text()?; let token_type_options = vec!["Personal", "OAuth", "Deploy", "App"]; @@ -895,7 +895,7 @@ impl ProfileCommand { } let confirm = Confirm::new() - .with_prompt(&format!("Remove token '{}' from profile '{}'?", service, profile_name)) + .with_prompt(format!("Remove token '{}' from profile '{}'?", service, profile_name)) .default(false) .interact()?; diff --git a/src/commands/tag.rs b/src/commands/tag.rs index e77da25..d97377c 100644 --- a/src/commands/tag.rs +++ b/src/commands/tag.rs @@ -56,6 +56,10 @@ pub struct TagCommand { #[arg(long)] dry_run: bool, + /// Enable thinking mode for this tag (override config) + #[arg(short = 't', long)] + think: bool, + /// Skip interactive prompts #[arg(short = 'y', long)] yes: bool, @@ -285,7 +289,7 @@ impl TagCommand { println!("{}", messages.ai_generating_tag(commits.len())); - let generator = ContentGenerator::new(&manager).await?; + let generator = ContentGenerator::new_with_think(&manager, self.think).await?; generator.generate_tag_message(version, &commits, language).await } diff --git a/src/config/manager.rs b/src/config/manager.rs index 9ed317b..8b3822d 100644 --- a/src/config/manager.rs +++ b/src/config/manager.rs @@ -136,11 +136,10 @@ impl ConfigManager { /// Set default profile pub fn set_default_profile(&mut self, name: Option) -> Result<()> { - if let Some(ref n) = name { - if !self.config.profiles.contains_key(n) { + if let Some(ref n) = name + && !self.config.profiles.contains_key(n) { bail!("Profile '{}' does not exist", n); } - } self.config.default_profile = name; self.modified = true; Ok(()) diff --git a/src/config/mod.rs b/src/config/mod.rs index b39a401..38383fa 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -111,9 +111,13 @@ pub struct LlmConfig { #[serde(default)] pub api_key: Option, - /// Enable thinking/reasoning mode (deepseek, kimi) + /// Enable thinking/reasoning mode (deepseek, kimi, anthropic) #[serde(default)] pub thinking_enabled: bool, + + /// Budget tokens for thinking mode (Anthropic Claude 4) + #[serde(default)] + pub thinking_budget_tokens: Option, } fn default_api_key_storage() -> String { @@ -132,6 +136,7 @@ impl Default for LlmConfig { api_key_storage: default_api_key_storage(), api_key: None, thinking_enabled: false, + thinking_budget_tokens: None, } } } diff --git a/src/config/profile.rs b/src/config/profile.rs index 76c5f3d..61fa6c5 100644 --- a/src/config/profile.rs +++ b/src/config/profile.rs @@ -119,9 +119,7 @@ impl GitProfile { /// Get signing key (from GPG config or direct) pub fn signing_key(&self) -> Option<&str> { - self.signing_key - .as_ref() - .map(|s| s.as_str()) + self.signing_key.as_deref() .or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str())) } @@ -175,8 +173,8 @@ impl GitProfile { } } - if let Some(ref ssh) = self.ssh { - if let Some(ref key_path) = ssh.private_key_path { + if let Some(ref ssh) = self.ssh + && let Some(ref key_path) = ssh.private_key_path { let path_str = key_path.display().to_string(); #[cfg(target_os = "windows")] { @@ -189,7 +187,6 @@ impl GitProfile { &format!("ssh -i '{}'", path_str))?; } } - } Ok(()) } @@ -213,8 +210,8 @@ impl GitProfile { } } - if let Some(ref ssh) = self.ssh { - if let Some(ref key_path) = ssh.private_key_path { + if let Some(ref ssh) = self.ssh + && let Some(ref key_path) = ssh.private_key_path { let path_str = key_path.display().to_string(); #[cfg(target_os = "windows")] { @@ -227,7 +224,6 @@ impl GitProfile { &format!("ssh -i '{}'", path_str))?; } } - } Ok(()) } @@ -264,8 +260,8 @@ impl GitProfile { }); } - if let Some(profile_key) = self.signing_key() { - if git_signing_key.as_deref() != Some(profile_key) { + if let Some(profile_key) = self.signing_key() + && git_signing_key.as_deref() != Some(profile_key) { comparison.matches = false; comparison.differences.push(ConfigDifference { key: "user.signingkey".to_string(), @@ -273,7 +269,6 @@ impl GitProfile { git_value: git_signing_key.unwrap_or_else(|| "".to_string()), }); } - } Ok(comparison) } @@ -281,6 +276,7 @@ impl GitProfile { /// Profile settings #[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default)] pub struct ProfileSettings { /// Automatically sign commits #[serde(default)] @@ -307,18 +303,6 @@ pub struct ProfileSettings { pub commit_template: Option, } -impl Default for ProfileSettings { - fn default() -> Self { - Self { - auto_sign_commits: false, - auto_sign_tags: false, - default_commit_format: None, - repo_patterns: vec![], - llm_provider: None, - commit_template: None, - } - } -} /// SSH configuration #[derive(Debug, Clone, Serialize, Deserialize)] @@ -349,17 +333,15 @@ pub struct SshConfig { impl SshConfig { /// Validate SSH configuration pub fn validate(&self) -> Result<()> { - if let Some(ref path) = self.private_key_path { - if !path.exists() { + if let Some(ref path) = self.private_key_path + && !path.exists() { bail!("SSH private key does not exist: {:?}", path); } - } - if let Some(ref path) = self.public_key_path { - if !path.exists() { + if let Some(ref path) = self.public_key_path + && !path.exists() { bail!("SSH public key does not exist: {:?}", path); } - } Ok(()) } @@ -495,7 +477,9 @@ impl TokenConfig { /// Token type #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] +#[derive(Default)] pub enum TokenType { + #[default] None, Personal, OAuth, @@ -503,11 +487,6 @@ pub enum TokenType { App, } -impl Default for TokenType { - fn default() -> Self { - Self::None - } -} impl std::fmt::Display for TokenType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/src/generator/mod.rs b/src/generator/mod.rs index 32de3cc..c9ddd11 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -12,15 +12,43 @@ pub struct ContentGenerator { impl ContentGenerator { /// Create new content generator pub async fn new(manager: &ConfigManager) -> Result { - let llm_client = LlmClient::from_config(manager).await?; - + Self::new_with_think(manager, false).await + } + + /// Create new content generator with thinking override + pub async fn new_with_think(manager: &ConfigManager, think_override: bool) -> Result { + let mut thinking_enabled = if think_override { + true + } else { + manager.config().llm.thinking_enabled + }; + + // Validate thinking support per provider + if thinking_enabled { + let provider = manager.llm_provider(); + if !Self::supports_thinking(provider) { + eprintln!( + "Warning: Provider '{}' does not support thinking mode. \ + Disabling thinking for this invocation.", + provider + ); + thinking_enabled = false; + } + } + + let llm_client = LlmClient::from_config_with_think(manager, thinking_enabled).await?; + if !llm_client.is_available().await { anyhow::bail!("LLM provider '{}' is not available", manager.llm_provider()); } - + Ok(Self { llm_client }) } + fn supports_thinking(provider: &str) -> bool { + matches!(provider, "deepseek" | "kimi" | "anthropic" | "openai") + } + /// Generate commit message from diff pub async fn generate_commit_message( &self, diff --git a/src/git/changelog.rs b/src/git/changelog.rs index 2cdae8c..09e351d 100644 --- a/src/git/changelog.rs +++ b/src/git/changelog.rs @@ -234,7 +234,7 @@ impl ChangelogGenerator { _date: DateTime, commits: &[CommitInfo], ) -> Result { - let mut output = format!("## What's Changed\n\n"); + let mut output = "## What's Changed\n\n".to_string(); // Group by type let mut features = vec![]; @@ -427,19 +427,16 @@ pub fn parse_versions(changelog: &str) -> Vec<(String, String)> { let mut versions = vec![]; for line in changelog.lines() { - if line.starts_with("## [") { - if let Some(start) = line.find('[') { - if let Some(end) = line.find(']') { + if line.starts_with("## [") + && let Some(start) = line.find('[') + && let Some(end) = line.find(']') { let version = &line[start + 1..end]; - if version != "Unreleased" { - if let Some(date_start) = line.find(" - ") { + if version != "Unreleased" + && let Some(date_start) = line.find(" - ") { let date = &line[date_start + 3..].trim(); versions.push((version.to_string(), date.to_string())); } - } } - } - } } versions diff --git a/src/git/mod.rs b/src/git/mod.rs index 2c0cbe3..ec865f6 100644 --- a/src/git/mod.rs +++ b/src/git/mod.rs @@ -2,7 +2,6 @@ use anyhow::{bail, Context, Result}; use git2::{Repository, Signature, StatusOptions, Config, Oid, ObjectType}; use std::path::{Path, PathBuf, Component}; use std::collections::HashMap; -use tempfile; pub mod changelog; pub mod commit; @@ -15,16 +14,14 @@ fn normalize_path_for_git2(path: &Path) -> PathBuf { #[cfg(target_os = "windows")] { let path_str = path.to_string_lossy(); - if path_str.starts_with(r"\\?\") { - if let Some(stripped) = path_str.strip_prefix(r"\\?\") { + if path_str.starts_with(r"\\?\") + && let Some(stripped) = path_str.strip_prefix(r"\\?\") { normalized = PathBuf::from(stripped); } - } - if path_str.starts_with(r"\\?\UNC\") { - if let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") { + if path_str.starts_with(r"\\?\UNC\") + && let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") { normalized = PathBuf::from(format!(r"\\{}", stripped)); } - } } normalized @@ -75,7 +72,7 @@ fn try_open_repo_with_git2(path: &Path) -> Result { let discover_opts = git2::RepositoryOpenFlags::empty(); let ceiling_dirs: [&str; 0] = []; - let repo = Repository::open_ext(&normalized, discover_opts, &ceiling_dirs) + let repo = Repository::open_ext(&normalized, discover_opts, ceiling_dirs) .or_else(|_| Repository::discover(&normalized)) .or_else(|_| Repository::open(&normalized)); @@ -84,7 +81,7 @@ fn try_open_repo_with_git2(path: &Path) -> Result { fn try_open_repo_with_git_cli(path: &Path) -> Result { let output = std::process::Command::new("git") - .args(&["rev-parse", "--show-toplevel"]) + .args(["rev-parse", "--show-toplevel"]) .current_dir(path) .output() .context("Failed to execute git command")?; @@ -330,7 +327,7 @@ impl GitRepo { pub fn get_staged_diff(&self) -> Result { // Use git CLI to get staged diff for better compatibility let output = std::process::Command::new("git") - .args(&["diff", "--cached"]) + .args(["diff", "--cached"]) .current_dir(&self.path) .output() .with_context(|| "Failed to get staged diff with git command")?; @@ -509,11 +506,10 @@ impl GitRepo { let mut files = vec![]; for entry in statuses.iter() { let status = entry.status(); - if status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange() { - if let Some(path) = entry.path() { + if (status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange()) + && let Some(path) = entry.path() { files.push(path.to_string()); } - } } Ok(files) @@ -542,7 +538,7 @@ impl GitRepo { pub fn stage_all(&self) -> Result<()> { // Use git command for reliable staging (handles all edge cases) let output = std::process::Command::new("git") - .args(&["add", "-A"]) + .args(["add", "-A"]) .current_dir(&self.path) .output() .with_context(|| "Failed to stage changes with git command")?; @@ -626,7 +622,7 @@ impl GitRepo { std::fs::write(temp_file.path(), message)?; let output = std::process::Command::new("git") - .args(&["commit", "-S", "-F", temp_file.path().to_str().unwrap()]) + .args(["commit", "-S", "-F", temp_file.path().to_str().unwrap()]) .current_dir(&self.path) .output()?; @@ -801,7 +797,7 @@ impl GitRepo { /// Create signed tag using git CLI fn create_signed_tag_with_git2(&self, name: &str, message: &str, _signature: &Signature, _target_id: Oid) -> Result<()> { let output = std::process::Command::new("git") - .args(&["tag", "-s", name, "-m", message]) + .args(["tag", "-s", name, "-m", message]) .current_dir(&self.path) .output()?; @@ -827,7 +823,7 @@ impl GitRepo { /// Push to remote pub fn push(&self, remote: &str, refspec: &str) -> Result<()> { let output = std::process::Command::new("git") - .args(&["push", remote, refspec]) + .args(["push", remote, refspec]) .current_dir(&self.path) .output()?; @@ -856,7 +852,7 @@ impl GitRepo { pub fn status_summary(&self) -> Result { // Use git CLI for more reliable status detection let output = std::process::Command::new("git") - .args(&["status", "--porcelain"]) + .args(["status", "--porcelain"]) .current_dir(&self.path) .output() .with_context(|| "Failed to get status with git command")?; @@ -1015,20 +1011,17 @@ pub fn find_repo>(start_path: P) -> Result { } if let Ok(output) = std::process::Command::new("git") - .args(&["rev-parse", "--show-toplevel"]) + .args(["rev-parse", "--show-toplevel"]) .current_dir(&resolved_start) .output() - { - if output.status.success() { + && output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout); let git_root = stdout.trim(); - if !git_root.is_empty() { - if let Ok(repo) = GitRepo::open(git_root) { + if !git_root.is_empty() + && let Ok(repo) = GitRepo::open(git_root) { return Ok(repo); } - } } - } let diagnosis = diagnose_repo_issue(&resolved_start); diff --git a/src/git/tag.rs b/src/git/tag.rs index bab8d5a..91dc5e2 100644 --- a/src/git/tag.rs +++ b/src/git/tag.rs @@ -283,7 +283,7 @@ pub fn delete_tag(repo: &GitRepo, name: &str, remote: Option<&str>) -> Result<() let refspec = format!(":refs/tags/{}", name); let output = Command::new("git") - .args(&["push", remote, &refspec]) + .args(["push", remote, &refspec]) .current_dir(repo.path()) .output()?; diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index 8627038..fc30352 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -1,7 +1,9 @@ +use super::thinking::ThinkingStateManager; use super::{create_http_client, LlmProvider}; use anyhow::{bail, Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use std::time::Duration; /// Anthropic Claude API client @@ -9,6 +11,12 @@ pub struct AnthropicClient { api_key: String, model: String, client: reqwest::Client, + thinking_enabled: bool, + thinking_budget_tokens: u32, + max_tokens: u32, + temperature: f32, + top_p: Option, + thinking_state: Option>, } #[derive(Debug, Serialize)] @@ -17,24 +25,58 @@ struct MessagesRequest { max_tokens: u32, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] - system: Option, + system: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + thinking: Option, + stream: bool, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Clone)] +struct SystemContent { + #[serde(rename = "type")] + content_type: String, + text: String, +} + +#[derive(Debug, Serialize)] +struct ThinkingConfig { + #[serde(rename = "type")] + thinking_type: String, + budget_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] struct AnthropicMessage { role: String, - content: String, + content: AnthropicContent, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +enum AnthropicContent { + Text(String), + Blocks(Vec), +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct ContentBlock { + #[serde(rename = "type")] + content_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + text: Option, } #[derive(Debug, Deserialize)] struct MessagesResponse { - content: Vec, + content: Vec, } #[derive(Debug, Deserialize)] -struct ContentBlock { +struct ResponseContentBlock { #[serde(rename = "type")] content_type: String, text: String, @@ -52,31 +94,112 @@ struct AnthropicError { message: String, } +// --- Streaming SSE event structures --- + +#[derive(Debug, Deserialize)] +struct SseEvent { + #[serde(rename = "type")] + event_type: String, + #[serde(default)] + message: Option, + #[serde(default)] + index: Option, + #[serde(default)] + content_block: Option, + #[serde(default)] + delta: Option, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct SseMessage { + #[serde(default)] + content: Option>, +} + +#[derive(Debug, Deserialize)] +struct SseContentBlock { + #[serde(rename = "type")] + content_type: String, + #[serde(default)] + thinking: Option, + #[serde(default)] + text: Option, +} + +#[derive(Debug, Deserialize)] +struct SseDelta { + #[serde(rename = "type")] + delta_type: Option, + #[serde(default)] + thinking: Option, + #[serde(default)] + text: Option, +} + +#[derive(Debug, Deserialize)] +struct SseUsage { + #[serde(default)] + output_tokens: Option, +} + impl AnthropicClient { - /// Create new Anthropic client pub fn new(api_key: &str, model: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; - + Ok(Self { api_key: api_key.to_string(), model: model.to_string(), client, + thinking_enabled: false, + thinking_budget_tokens: 1024, + max_tokens: 500, + temperature: 0.7, + top_p: None, + thinking_state: None, }) } - /// Set timeout pub fn with_timeout(mut self, timeout: Duration) -> Result { self.client = create_http_client(timeout)?; Ok(self) } - /// List available models + pub fn with_thinking(mut self, enabled: bool) -> Self { + self.thinking_enabled = enabled; + self + } + + pub fn with_thinking_budget_tokens(mut self, budget_tokens: u32) -> Self { + self.thinking_budget_tokens = budget_tokens; + self + } + + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn with_thinking_state(mut self, state: Arc) -> Self { + self.thinking_state = Some(state); + self + } + pub async fn list_models(&self) -> Result> { - // Anthropic doesn't have a models API endpoint, return predefined list Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect()) } - /// Validate API key pub async fn validate_key(&self) -> Result { let url = "https://api.anthropic.com/v1/messages"; @@ -84,14 +207,18 @@ impl AnthropicClient { model: self.model.clone(), max_tokens: 5, temperature: Some(0.0), + top_p: None, messages: vec![AnthropicMessage { role: "user".to_string(), - content: "Hi".to_string(), + content: AnthropicContent::Text("Hi".to_string()), }], system: None, + thinking: None, + stream: false, }; - let response = self.client + let response = self + .client .post(url) .header("x-api-key", &self.api_key) .header("anthropic-version", "2023-06-01") @@ -124,25 +251,28 @@ impl LlmProvider for AnthropicClient { async fn generate(&self, prompt: &str) -> Result { let messages = vec![AnthropicMessage { role: "user".to_string(), - content: prompt.to_string(), + content: AnthropicContent::Text(prompt.to_string()), }]; - - self.messages_request(messages, None).await + + self.messages_request_with_retry(messages, None).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let messages = vec![AnthropicMessage { role: "user".to_string(), - content: user.to_string(), + content: AnthropicContent::Text(user.to_string()), }]; - + let system = if system.is_empty() { None } else { - Some(system.to_string()) + Some(vec![SystemContent { + content_type: "text".to_string(), + text: system.to_string(), + }]) }; - - self.messages_request(messages, system).await + + self.messages_request_with_retry(messages, system).await } async fn is_available(&self) -> bool { @@ -155,22 +285,81 @@ impl LlmProvider for AnthropicClient { } impl AnthropicClient { + async fn messages_request_with_retry( + &self, + messages: Vec, + system: Option>, + ) -> Result { + let mut last_error = None; + + for attempt in 1..=3 { + match self + .messages_request(messages.clone(), system.clone()) + .await + { + Ok(result) => return Ok(result), + Err(e) => { + let err_msg = e.to_string(); + let is_retryable = err_msg.contains("timeout") + || err_msg.contains("connection") + || err_msg.contains("temporary") + || err_msg.contains("5") + && (err_msg.contains("500") + || err_msg.contains("502") + || err_msg.contains("503") + || err_msg.contains("504")); + + if !is_retryable || attempt == 3 { + last_error = Some(e); + break; + } + + tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await; + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries"))) + } + async fn messages_request( &self, messages: Vec, - system: Option, + system: Option>, + ) -> Result { + if self.thinking_enabled { + self.streaming_messages_request(messages, system).await + } else { + self.non_streaming_messages_request(messages, system).await + } + } + + async fn non_streaming_messages_request( + &self, + messages: Vec, + system: Option>, ) -> Result { let url = "https://api.anthropic.com/v1/messages"; - + + let temperature = if self.temperature == 0.0 { + None + } else { + Some(self.temperature) + }; + let request = MessagesRequest { model: self.model.clone(), - max_tokens: 500, - temperature: Some(0.7), + max_tokens: self.max_tokens, + temperature, + top_p: self.top_p, messages, system, + thinking: None, + stream: false, }; - - let response = self.client + + let response = self + .client .post(url) .header("x-api-key", &self.api_key) .header("anthropic-version", "2023-06-01") @@ -179,35 +368,205 @@ impl AnthropicClient { .send() .await .context("Failed to send request to Anthropic")?; - + let status = response.status(); - + if !status.is_success() { let text = response.text().await.unwrap_or_default(); - - // Try to parse error + if let Ok(error) = serde_json::from_str::(&text) { - bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type); + bail!( + "Anthropic API error: {} ({})", + error.error.message, + error.error.error_type + ); } - + bail!("Anthropic API error: {} - {}", status, text); } - + let result: MessagesResponse = response .json() .await .context("Failed to parse Anthropic response")?; - - result.content + + result + .content .into_iter() .find(|c| c.content_type == "text") .map(|c| c.text.trim().to_string()) + .filter(|s| !s.is_empty()) .ok_or_else(|| anyhow::anyhow!("No text response from Anthropic")) } + + /// Streaming request for thinking mode, filters thinking content blocks + async fn streaming_messages_request( + &self, + messages: Vec, + system: Option>, + ) -> Result { + let url = "https://api.anthropic.com/v1/messages"; + + let thinking = ThinkingConfig { + thinking_type: "enabled".to_string(), + budget_tokens: self.thinking_budget_tokens, + }; + + // max_tokens must exceed budget_tokens + let max_tokens = (self.max_tokens).max(self.thinking_budget_tokens + 100); + + let request = MessagesRequest { + model: self.model.clone(), + max_tokens, + temperature: None, // must be omitted for thinking mode + top_p: None, + messages, + system, + thinking: Some(thinking), + stream: true, + }; + + let response = self + .client + .post(url) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(&request) + .send() + .await + .context("Failed to send streaming request to Anthropic")?; + + let status = response.status(); + + if !status.is_success() { + let text = response.text().await.unwrap_or_default(); + + if let Ok(error) = serde_json::from_str::(&text) { + bail!( + "Anthropic API error: {} ({})", + error.error.message, + error.error.error_type + ); + } + + bail!("Anthropic API error: {} - {}", status, text); + } + + let mut content_buffer = String::new(); + let mut in_thinking = false; + let mut has_reasoning = false; + let mut has_content = false; + + let thinking_state = self.thinking_state.as_ref(); + + let mut byte_stream = response.bytes_stream(); + let mut line_buffer = String::new(); + + use futures_util::StreamExt; + + while let Some(chunk) = byte_stream.next().await { + let chunk = chunk.context("Failed to read streaming response chunk")?; + let chunk_str = + String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?; + + line_buffer.push_str(&chunk_str); + + while let Some(line_end) = line_buffer.find('\n') { + let line = line_buffer[..line_end].trim().to_string(); + line_buffer = line_buffer[line_end + 1..].to_string(); + + if line.is_empty() { + continue; + } + + // Parse SSE event line + if let Some(data) = line.strip_prefix("data: ") { + if let Ok(event) = serde_json::from_str::(data) { + match event.event_type.as_str() { + "content_block_start" => { + if let Some(ref block) = event.content_block { + if block.content_type == "thinking" { + in_thinking = true; + if !has_reasoning { + has_reasoning = true; + if let Some(state) = thinking_state { + state.start_thinking(); + } + } + } + } + } + "content_block_delta" => { + if let Some(ref delta) = event.delta { + // Thinking delta - ignore content but track state + if delta.thinking.is_some() { + continue; + } + + // Text delta - collect + if in_thinking && delta.text.is_some() { + // Transition from thinking to text + if let Some(state) = thinking_state { + state.end_thinking(); + } + in_thinking = false; + } + if let Some(ref text) = delta.text + && !text.is_empty() + { + has_content = true; + content_buffer.push_str(text); + } + } + } + "content_block_stop" => { + if in_thinking { + if let Some(state) = thinking_state { + state.end_thinking(); + } + in_thinking = false; + } + } + _ => {} + } + } + } + } + } + + // Ensure thinking state is ended + if let Some(state) = thinking_state { + state.end_thinking(); + } + + let result = content_buffer.trim().to_string(); + + if result.is_empty() { + if has_reasoning && !has_content { + bail!( + "Anthropic returned thinking content but no final answer. \ + The model may have entered an incomplete thinking state. \ + Please try again or disable thinking mode." + ); + } + bail!( + "No response from Anthropic. \ + If thinking mode is enabled, try disabling it or ensure the model supports it." + ); + } + + Ok(result) + } } -/// Available Anthropic models +/// Available Anthropic models (Claude 4 series with extended thinking) pub const ANTHROPIC_MODELS: &[&str] = &[ + "claude-opus-4-7", + "claude-sonnet-4-6", + "claude-haiku-4-5", + // Legacy models "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", @@ -216,7 +575,6 @@ pub const ANTHROPIC_MODELS: &[&str] = &[ "claude-instant-1.2", ]; -/// Check if a model name is valid pub fn is_valid_model(model: &str) -> bool { ANTHROPIC_MODELS.contains(&model) } @@ -226,8 +584,61 @@ mod tests { use super::*; #[test] - fn test_model_validation() { + fn test_model_validation_claude4() { + assert!(is_valid_model("claude-opus-4-7")); + assert!(is_valid_model("claude-sonnet-4-6")); + assert!(is_valid_model("claude-haiku-4-5")); assert!(is_valid_model("claude-3-sonnet-20240229")); assert!(!is_valid_model("invalid-model")); } + + #[test] + fn test_thinking_config_serialization() { + let config = ThinkingConfig { + thinking_type: "enabled".to_string(), + budget_tokens: 2048, + }; + let json = serde_json::to_string(&config).unwrap(); + assert!(json.contains(r#""type":"enabled""#)); + assert!(json.contains(r#""budget_tokens":2048"#)); + } + + #[test] + fn test_system_content_serialization() { + let content = SystemContent { + content_type: "text".to_string(), + text: "You are helpful.".to_string(), + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains(r#""type":"text""#)); + } + + #[test] + fn test_sse_event_parsing_content_block_start() { + let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#; + let event: SseEvent = serde_json::from_str(json).unwrap(); + assert_eq!(event.event_type, "content_block_start"); + assert_eq!( + event.content_block.unwrap().content_type, + "thinking" + ); + } + + #[test] + fn test_sse_event_parsing_text_delta() { + let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#; + let event: SseEvent = serde_json::from_str(json).unwrap(); + assert_eq!(event.event_type, "content_block_delta"); + assert_eq!(event.delta.unwrap().text, Some("Hello".to_string())); + } + + #[test] + fn test_anthropic_content_text() { + let msg = AnthropicMessage { + role: "user".to_string(), + content: AnthropicContent::Text("Hello".to_string()), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains(r#""content":"Hello""#)); + } } diff --git a/src/llm/deepseek.rs b/src/llm/deepseek.rs index 30677cf..082eb61 100644 --- a/src/llm/deepseek.rs +++ b/src/llm/deepseek.rs @@ -1,7 +1,9 @@ +use super::thinking::ThinkingStateManager; use super::{create_http_client, LlmProvider}; use anyhow::{bail, Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use std::time::Duration; /// DeepSeek API client @@ -11,6 +13,10 @@ pub struct DeepSeekClient { model: String, client: reqwest::Client, thinking_enabled: bool, + reasoning_effort: Option, + max_tokens: u32, + temperature: f32, + thinking_state: Option>, } #[derive(Debug, Serialize)] @@ -21,9 +27,17 @@ struct ChatCompletionRequest { max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + frequency_penalty: Option, stream: bool, #[serde(skip_serializing_if = "Option::is_none")] thinking: Option, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_effort: Option, } #[derive(Debug, Serialize)] @@ -32,10 +46,12 @@ struct ThinkingConfig { thinking_type: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct Message { role: String, content: String, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, } #[derive(Debug, Deserialize)] @@ -50,6 +66,29 @@ struct Choice { reasoning_content: Option, } +// --- Streaming response structures --- + +#[derive(Debug, Deserialize)] +struct StreamChunk { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + #[serde(default)] + finish_reason: Option, + index: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct StreamDelta { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, +} + #[derive(Debug, Deserialize)] struct ErrorResponse { error: ApiError, @@ -63,22 +102,24 @@ struct ApiError { } impl DeepSeekClient { - /// Create new DeepSeek client pub fn new(api_key: &str, model: &str) -> Result { - let client = create_http_client(Duration::from_secs(60))?; + let client = create_http_client(Duration::from_secs(300))?; Ok(Self { - base_url: "https://api.deepseek.com/".to_string(), + base_url: "https://api.deepseek.com".to_string(), api_key: api_key.to_string(), model: model.to_string(), client, thinking_enabled: false, + reasoning_effort: None, + max_tokens: 500, + temperature: 0.7, + thinking_state: None, }) } - /// Create with custom base URL pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { - let client = create_http_client(Duration::from_secs(60))?; + let client = create_http_client(Duration::from_secs(300))?; Ok(Self { base_url: base_url.trim_end_matches('/').to_string(), @@ -86,26 +127,48 @@ impl DeepSeekClient { model: model.to_string(), client, thinking_enabled: false, + reasoning_effort: None, + max_tokens: 500, + temperature: 0.7, + thinking_state: None, }) } - /// Set timeout pub fn with_timeout(mut self, timeout: Duration) -> Result { self.client = create_http_client(timeout)?; Ok(self) } - /// Enable or disable thinking mode pub fn with_thinking(mut self, enabled: bool) -> Self { self.thinking_enabled = enabled; self } - /// List available models + pub fn with_reasoning_effort(mut self, effort: Option) -> Self { + self.reasoning_effort = effort; + self + } + + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_thinking_state(mut self, state: Arc) -> Self { + self.thinking_state = Some(state); + self + } + pub async fn list_models(&self) -> Result> { let url = format!("{}/models", self.base_url); - let response = self.client + let response = self + .client .get(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .send() @@ -120,11 +183,11 @@ impl DeepSeekClient { #[derive(Deserialize)] struct ModelsResponse { - data: Vec, + data: Vec, } #[derive(Deserialize)] - struct Model { + struct ModelId { id: String, } @@ -136,7 +199,6 @@ impl DeepSeekClient { Ok(result.data.into_iter().map(|m| m.id).collect()) } - /// Validate API key pub async fn validate_key(&self) -> Result { match self.list_models().await { Ok(_) => Ok(true), @@ -155,14 +217,13 @@ impl DeepSeekClient { #[async_trait] impl LlmProvider for DeepSeekClient { async fn generate(&self, prompt: &str) -> Result { - let messages = vec![ - Message { - role: "user".to_string(), - content: prompt.to_string(), - }, - ]; + let messages = vec![Message { + role: "user".to_string(), + content: prompt.to_string(), + reasoning_content: None, + }]; - self.chat_completion(messages).await + self.chat_completion_with_retry(messages).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { @@ -172,15 +233,17 @@ impl LlmProvider for DeepSeekClient { messages.push(Message { role: "system".to_string(), content: system.to_string(), + reasoning_content: None, }); } messages.push(Message { role: "user".to_string(), content: user.to_string(), + reasoning_content: None, }); - self.chat_completion(messages).await + self.chat_completion_with_retry(messages).await } async fn is_available(&self) -> bool { @@ -193,6 +256,38 @@ impl LlmProvider for DeepSeekClient { } impl DeepSeekClient { + async fn chat_completion_with_retry(&self, messages: Vec) -> Result { + let mut last_error = None; + + for attempt in 1..=3 { + match self.chat_completion(messages.clone()).await { + Ok(result) => return Ok(result), + Err(e) => { + let err_msg = e.to_string(); + // 网络临时错误才重试 + let is_retryable = err_msg.contains("timeout") + || err_msg.contains("connection") + || err_msg.contains("temporary") + || err_msg.contains("5") + && (err_msg.contains("500") + || err_msg.contains("502") + || err_msg.contains("503") + || err_msg.contains("504")); + + if !is_retryable || attempt == 3 { + last_error = Some(e); + break; + } + + // 指数退避 + tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await; + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries"))) + } + async fn chat_completion(&self, messages: Vec) -> Result { let url = format!("{}/chat/completions", self.base_url); @@ -204,20 +299,59 @@ impl DeepSeekClient { None }; - let request = ChatCompletionRequest { - model: self.model.clone(), - messages, - max_tokens: Some(500), - temperature: Some(0.7), - stream: false, - thinking, + // 思考模式下,temperature/top_p 等参数不应传递 + // 非思考模式下可以正常传递 + let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) = + if self.thinking_enabled { + (None, Some(self.max_tokens), None, None, None) + } else { + ( + Some(self.temperature), + Some(self.max_tokens), + None, + None, + None, + ) + }; + + let reasoning_effort = if self.thinking_enabled { + self.reasoning_effort.clone() + } else { + None }; - let response = self.client - .post(&url) + let request = ChatCompletionRequest { + model: self.model.clone(), + messages: messages.clone(), + max_tokens, + temperature, + top_p, + presence_penalty, + frequency_penalty, + stream: self.thinking_enabled, + thinking, + reasoning_effort, + }; + + if self.thinking_enabled { + self.streaming_chat_completion(&url, &request).await + } else { + self.non_streaming_chat_completion(&url, &request).await + } + } + + /// 非流式请求(非思考模式) + async fn non_streaming_chat_completion( + &self, + url: &str, + request: &ChatCompletionRequest, + ) -> Result { + let response = self + .client + .post(url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") - .json(&request) + .json(request) .send() .await .context("Failed to send request to DeepSeek")?; @@ -228,7 +362,11 @@ impl DeepSeekClient { let text = response.text().await.unwrap_or_default(); if let Ok(error) = serde_json::from_str::(&text) { - bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type); + bail!( + "DeepSeek API error: {} ({})", + error.error.message, + error.error.error_type + ); } bail!("DeepSeek API error: {} - {}", status, text); @@ -239,34 +377,165 @@ impl DeepSeekClient { .await .context("Failed to parse DeepSeek response")?; - result.choices + result + .choices .into_iter() .next() - .map(|c| { - let content = c.message.content.trim().to_string(); - if content.is_empty() { - c.reasoning_content - .map(|r| r.trim().to_string()) - .unwrap_or_default() - } else { - content - } - }) + .map(|c| c.message.content.trim().to_string()) .filter(|s| !s.is_empty()) - .ok_or_else(|| anyhow::anyhow!( + .ok_or_else(|| anyhow::anyhow!("No response from DeepSeek")) + } + + /// 流式请求(思考模式),处理 reasoning_content 和 content + async fn streaming_chat_completion( + &self, + url: &str, + request: &ChatCompletionRequest, + ) -> Result { + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(request) + .send() + .await + .context("Failed to send streaming request to DeepSeek")?; + + let status = response.status(); + + if !status.is_success() { + let text = response.text().await.unwrap_or_default(); + + if let Ok(error) = serde_json::from_str::(&text) { + bail!( + "DeepSeek API error: {} ({})", + error.error.message, + error.error.error_type + ); + } + + bail!("DeepSeek API error: {} - {}", status, text); + } + + let mut content_buffer = String::new(); + let mut has_reasoning = false; + let mut has_content = false; + let mut stream_ended = false; + + let thinking_state = self.thinking_state.as_ref(); + + let mut byte_stream = response.bytes_stream(); + let mut line_buffer = String::new(); + + use futures_util::StreamExt; + + while let Some(chunk) = byte_stream.next().await { + let chunk = chunk.context("Failed to read streaming response chunk")?; + let chunk_str = + String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?; + + line_buffer.push_str(&chunk_str); + + // 处理完整行 + while let Some(line_end) = line_buffer.find('\n') { + let line = line_buffer[..line_end].trim().to_string(); + line_buffer = line_buffer[line_end + 1..].to_string(); + + if line.is_empty() { + continue; + } + + // SSE 格式:data: {...} 或 data: [DONE] + if line == "data: [DONE]" { + stream_ended = true; + break; + } + + if let Some(json_str) = line.strip_prefix("data: ") { + match serde_json::from_str::(json_str) { + Ok(chunk) => { + for choice in &chunk.choices { + // 处理 reasoning_content + if let Some(ref reasoning) = choice.delta.reasoning_content + && !reasoning.is_empty() { + if !has_reasoning { + has_reasoning = true; + if let Some(state) = thinking_state { + state.start_thinking(); + } + } + // reasoning_content 不对外输出,仅用于内部状态判断 + continue; + } + + // 处理 content + if let Some(ref content) = choice.delta.content + && !content.is_empty() { + // reasoning 结束,content 开始出现时移除 thinking 标识 + if has_reasoning && !has_content + && let Some(state) = thinking_state { + state.end_thinking(); + } + has_content = true; + content_buffer.push_str(content); + } + + // 检查 finish_reason + if let Some(ref reason) = choice.finish_reason + && reason == "stop" { + stream_ended = true; + } + } + } + Err(_) => { + // 忽略无法解析的行(可能是心跳或注释) + } + } + } + } + + if stream_ended { + break; + } + } + + // 确保思考状态已结束 + if let Some(state) = thinking_state { + state.end_thinking(); + } + + let result = content_buffer.trim().to_string(); + + if result.is_empty() { + if has_reasoning && !has_content { + bail!( + "DeepSeek returned reasoning content but no final answer. \ + The model may have entered an incomplete thinking state. \ + Please try again or disable thinking mode." + ); + } + bail!( "No response from DeepSeek. \ If thinking mode is enabled, try disabling it or ensure the model supports it." - )) + ); + } + + Ok(result) } } -/// Available DeepSeek models +/// 可用 DeepSeek 模型列表 +/// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列 pub const DEEPSEEK_MODELS: &[&str] = &[ + "deepseek-v4-flash", + "deepseek-v4-pro", + // 兼容旧版模型 ID(将于 2026-07-24 停用) "deepseek-chat", "deepseek-reasoner", ]; -/// Check if a model name is valid pub fn is_valid_model(model: &str) -> bool { DEEPSEEK_MODELS.contains(&model) } @@ -276,9 +545,76 @@ mod tests { use super::*; #[test] - fn test_model_validation() { + fn test_model_validation_v4() { + assert!(is_valid_model("deepseek-v4-flash")); + assert!(is_valid_model("deepseek-v4-pro")); assert!(is_valid_model("deepseek-chat")); assert!(is_valid_model("deepseek-reasoner")); assert!(!is_valid_model("invalid-model")); + assert!(!is_valid_model("deepseek-v3")); + } + + #[test] + fn test_client_builder_defaults() { + let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap(); + assert!(!client.thinking_enabled); + assert_eq!(client.max_tokens, 500); + assert_eq!(client.temperature, 0.7); + assert!(client.reasoning_effort.is_none()); + assert!(client.thinking_state.is_none()); + } + + #[test] + fn test_client_builder_with_thinking() { + let client = DeepSeekClient::new("test-key", "deepseek-v4-flash") + .unwrap() + .with_thinking(true) + .with_reasoning_effort(Some("high".to_string())) + .with_max_tokens(1000) + .with_temperature(0.5); + + assert!(client.thinking_enabled); + assert_eq!(client.reasoning_effort, Some("high".to_string())); + assert_eq!(client.max_tokens, 1000); + assert_eq!(client.temperature, 0.5); + } + + #[test] + fn test_thinking_config_serialization() { + let config = ThinkingConfig { + thinking_type: "enabled".to_string(), + }; + let json = serde_json::to_string(&config).unwrap(); + assert_eq!(json, r#"{"type":"enabled"}"#); + } + + #[test] + fn test_message_serialization_without_reasoning() { + let msg = Message { + role: "user".to_string(), + content: "Hello".to_string(), + reasoning_content: None, + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(!json.contains("reasoning_content")); + } + + #[test] + fn test_stream_delta_parsing() { + let json = r#"{"content":"Hello","reasoning_content":null}"#; + let delta: StreamDelta = serde_json::from_str(json).unwrap(); + assert_eq!(delta.content, Some("Hello".to_string())); + assert!(delta.reasoning_content.is_none()); + } + + #[test] + fn test_stream_delta_reasoning_only() { + let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#; + let delta: StreamDelta = serde_json::from_str(json).unwrap(); + assert!(delta.content.is_none()); + assert_eq!( + delta.reasoning_content, + Some("Let me think...".to_string()) + ); } } diff --git a/src/llm/kimi.rs b/src/llm/kimi.rs index 41c34b9..e8f15e4 100644 --- a/src/llm/kimi.rs +++ b/src/llm/kimi.rs @@ -1,284 +1,571 @@ -use super::{create_http_client, LlmProvider}; -use anyhow::{bail, Context, Result}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use std::time::Duration; - -/// Kimi API client (Moonshot AI) -pub struct KimiClient { - base_url: String, - api_key: String, - model: String, - client: reqwest::Client, - thinking_enabled: bool, -} - -#[derive(Debug, Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - thinking: Option, -} - -#[derive(Debug, Serialize)] -struct ThinkingConfig { - #[serde(rename = "type")] - thinking_type: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct ChatCompletionResponse { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: Message, - #[serde(default)] - reasoning_content: Option, -} - -#[derive(Debug, Deserialize)] -struct ErrorResponse { - error: ApiError, -} - -#[derive(Debug, Deserialize)] -struct ApiError { - message: String, - #[serde(rename = "type")] - error_type: String, -} - -impl KimiClient { - /// Create new Kimi client - pub fn new(api_key: &str, model: &str) -> Result { - let client = create_http_client(Duration::from_secs(60))?; - - Ok(Self { - base_url: "https://api.moonshot.cn/v1".to_string(), - api_key: api_key.to_string(), - model: model.to_string(), - client, - thinking_enabled: false, - }) - } - - /// Create with custom base URL - pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { - let client = create_http_client(Duration::from_secs(60))?; - - Ok(Self { - base_url: base_url.trim_end_matches('/').to_string(), - api_key: api_key.to_string(), - model: model.to_string(), - client, - thinking_enabled: false, - }) - } - - /// Set timeout - pub fn with_timeout(mut self, timeout: Duration) -> Result { - self.client = create_http_client(timeout)?; - Ok(self) - } - - /// Enable or disable thinking mode - pub fn with_thinking(mut self, enabled: bool) -> Self { - self.thinking_enabled = enabled; - self - } - - /// List available models - pub async fn list_models(&self) -> Result> { - let url = format!("{}/models", self.base_url); - - let response = self.client - .get(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .send() - .await - .context("Failed to list Kimi models")?; - - if !response.status().is_success() { - let status = response.status(); - let text = response.text().await.unwrap_or_default(); - bail!("Kimi API error: {} - {}", status, text); - } - - #[derive(Deserialize)] - struct ModelsResponse { - data: Vec, - } - - #[derive(Deserialize)] - struct Model { - id: String, - } - - let result: ModelsResponse = response - .json() - .await - .context("Failed to parse Kimi response")?; - - Ok(result.data.into_iter().map(|m| m.id).collect()) - } - - /// Validate API key - pub async fn validate_key(&self) -> Result { - match self.list_models().await { - Ok(_) => Ok(true), - Err(e) => { - let err_str = e.to_string(); - if err_str.contains("401") || err_str.contains("Unauthorized") { - Ok(false) - } else { - Err(e) - } - } - } - } -} - -#[async_trait] -impl LlmProvider for KimiClient { - async fn generate(&self, prompt: &str) -> Result { - let messages = vec![ - Message { - role: "user".to_string(), - content: prompt.to_string(), - }, - ]; - - self.chat_completion(messages).await - } - - async fn generate_with_system(&self, system: &str, user: &str) -> Result { - let mut messages = vec![]; - - if !system.is_empty() { - messages.push(Message { - role: "system".to_string(), - content: system.to_string(), - }); - } - - messages.push(Message { - role: "user".to_string(), - content: user.to_string(), - }); - - self.chat_completion(messages).await - } - - async fn is_available(&self) -> bool { - self.validate_key().await.unwrap_or(false) - } - - fn name(&self) -> &str { - "kimi" - } -} - -impl KimiClient { - async fn chat_completion(&self, messages: Vec) -> Result { - let url = format!("{}/chat/completions", self.base_url); - - let thinking = if self.thinking_enabled { - Some(ThinkingConfig { - thinking_type: "enabled".to_string(), - }) - } else { - None - }; - - let request = ChatCompletionRequest { - model: self.model.clone(), - messages, - max_tokens: Some(500), - temperature: Some(1.0), - stream: false, - thinking, - }; - - let response = self.client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&request) - .send() - .await - .context("Failed to send request to Kimi")?; - - let status = response.status(); - - if !status.is_success() { - let text = response.text().await.unwrap_or_default(); - - if let Ok(error) = serde_json::from_str::(&text) { - bail!("Kimi API error: {} ({})", error.error.message, error.error.error_type); - } - - bail!("Kimi API error: {} - {}", status, text); - } - - let result: ChatCompletionResponse = response - .json() - .await - .context("Failed to parse Kimi response")?; - - result.choices - .into_iter() - .next() - .map(|c| { - let content = c.message.content.trim().to_string(); - if content.is_empty() { - c.reasoning_content - .map(|r| r.trim().to_string()) - .unwrap_or_default() - } else { - content - } - }) - .filter(|s| !s.is_empty()) - .ok_or_else(|| anyhow::anyhow!( - "No response from Kimi. \ - If thinking mode is enabled, try disabling it or ensure the model supports it." - )) - } -} - -/// Available Kimi models -pub const KIMI_MODELS: &[&str] = &[ - "moonshot-v1-8k", - "moonshot-v1-32k", - "moonshot-v1-128k", -]; - -/// Check if a model name is valid -pub fn is_valid_model(model: &str) -> bool { - KIMI_MODELS.contains(&model) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_model_validation() { - assert!(is_valid_model("moonshot-v1-8k")); - assert!(!is_valid_model("invalid-model")); - } -} +use super::thinking::ThinkingStateManager; +use super::{create_http_client, LlmProvider}; +use anyhow::{bail, Context, Result}; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Duration; + +/// Kimi API client (Moonshot AI) +pub struct KimiClient { + base_url: String, + api_key: String, + model: String, + client: reqwest::Client, + thinking_enabled: bool, + max_tokens: u32, + temperature: f32, + thinking_state: Option>, +} + +#[derive(Debug, Serialize)] +struct ChatCompletionRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + thinking: Option, +} + +#[derive(Debug, Serialize)] +struct ThinkingConfig { + #[serde(rename = "type")] + thinking_type: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Message { + role: String, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: Message, + #[serde(default)] + reasoning_content: Option, +} + +// --- Streaming response structures --- + +#[derive(Debug, Deserialize)] +struct StreamChunk { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + #[serde(default)] + finish_reason: Option, + index: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct StreamDelta { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorResponse { + error: ApiError, +} + +#[derive(Debug, Deserialize)] +struct ApiError { + message: String, + #[serde(rename = "type")] + error_type: String, +} + +impl KimiClient { + pub fn new(api_key: &str, model: &str) -> Result { + let client = create_http_client(Duration::from_secs(300))?; + + Ok(Self { + base_url: "https://api.moonshot.cn/v1".to_string(), + api_key: api_key.to_string(), + model: model.to_string(), + client, + thinking_enabled: false, + max_tokens: 500, + temperature: 1.0, + thinking_state: None, + }) + } + + pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { + let client = create_http_client(Duration::from_secs(300))?; + + Ok(Self { + base_url: base_url.trim_end_matches('/').to_string(), + api_key: api_key.to_string(), + model: model.to_string(), + client, + thinking_enabled: false, + max_tokens: 500, + temperature: 1.0, + thinking_state: None, + }) + } + + pub fn with_timeout(mut self, timeout: Duration) -> Result { + self.client = create_http_client(timeout)?; + Ok(self) + } + + pub fn with_thinking(mut self, enabled: bool) -> Self { + self.thinking_enabled = enabled; + self + } + + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_thinking_state(mut self, state: Arc) -> Self { + self.thinking_state = Some(state); + self + } + + pub async fn list_models(&self) -> Result> { + let url = format!("{}/models", self.base_url); + + let response = self + .client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .send() + .await + .context("Failed to list Kimi models")?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + bail!("Kimi API error: {} - {}", status, text); + } + + #[derive(Deserialize)] + struct ModelsResponse { + data: Vec, + } + + #[derive(Deserialize)] + struct ModelId { + id: String, + } + + let result: ModelsResponse = response + .json() + .await + .context("Failed to parse Kimi response")?; + + Ok(result.data.into_iter().map(|m| m.id).collect()) + } + + pub async fn validate_key(&self) -> Result { + match self.list_models().await { + Ok(_) => Ok(true), + Err(e) => { + let err_str = e.to_string(); + if err_str.contains("401") || err_str.contains("Unauthorized") { + Ok(false) + } else { + Err(e) + } + } + } + } +} + +#[async_trait] +impl LlmProvider for KimiClient { + async fn generate(&self, prompt: &str) -> Result { + let messages = vec![Message { + role: "user".to_string(), + content: prompt.to_string(), + reasoning_content: None, + }]; + + self.chat_completion_with_retry(messages).await + } + + async fn generate_with_system(&self, system: &str, user: &str) -> Result { + let mut messages = vec![]; + + if !system.is_empty() { + messages.push(Message { + role: "system".to_string(), + content: system.to_string(), + reasoning_content: None, + }); + } + + messages.push(Message { + role: "user".to_string(), + content: user.to_string(), + reasoning_content: None, + }); + + self.chat_completion_with_retry(messages).await + } + + async fn is_available(&self) -> bool { + self.validate_key().await.unwrap_or(false) + } + + fn name(&self) -> &str { + "kimi" + } +} + +impl KimiClient { + async fn chat_completion_with_retry(&self, messages: Vec) -> Result { + let mut last_error = None; + + for attempt in 1..=3 { + match self.chat_completion(messages.clone()).await { + Ok(result) => return Ok(result), + Err(e) => { + let err_msg = e.to_string(); + let is_retryable = err_msg.contains("timeout") + || err_msg.contains("connection") + || err_msg.contains("temporary") + || err_msg.contains("5") + && (err_msg.contains("500") + || err_msg.contains("502") + || err_msg.contains("503") + || err_msg.contains("504")); + + if !is_retryable || attempt == 3 { + last_error = Some(e); + break; + } + + tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await; + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries"))) + } + + async fn chat_completion(&self, messages: Vec) -> Result { + let url = format!("{}/chat/completions", self.base_url); + + let thinking = if self.thinking_enabled { + Some(ThinkingConfig { + thinking_type: "enabled".to_string(), + }) + } else { + None + }; + + // 对于 kimi-k2.6 等支持思考模式的模型,使用默认 temperature 即可 + // 思考模式下不显式指定 temperature + let temperature = if self.thinking_enabled { + None + } else { + Some(self.temperature) + }; + + let request = ChatCompletionRequest { + model: self.model.clone(), + messages: messages.clone(), + max_tokens: Some(self.max_tokens), + temperature, + stream: self.thinking_enabled, + thinking, + }; + + if self.thinking_enabled { + self.streaming_chat_completion(&url, &request).await + } else { + self.non_streaming_chat_completion(&url, &request).await + } + } + + /// 非流式请求(非思考模式) + async fn non_streaming_chat_completion( + &self, + url: &str, + request: &ChatCompletionRequest, + ) -> Result { + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(request) + .send() + .await + .context("Failed to send request to Kimi")?; + + let status = response.status(); + + if !status.is_success() { + let text = response.text().await.unwrap_or_default(); + + if let Ok(error) = serde_json::from_str::(&text) { + bail!( + "Kimi API error: {} ({})", + error.error.message, + error.error.error_type + ); + } + + bail!("Kimi API error: {} - {}", status, text); + } + + let result: ChatCompletionResponse = response + .json() + .await + .context("Failed to parse Kimi response")?; + + result + .choices + .into_iter() + .next() + .map(|c| c.message.content.trim().to_string()) + .filter(|s| !s.is_empty()) + .ok_or_else(|| anyhow::anyhow!("No response from Kimi")) + } + + /// 流式请求(思考模式),处理 reasoning_content 和 content + async fn streaming_chat_completion( + &self, + url: &str, + request: &ChatCompletionRequest, + ) -> Result { + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(request) + .send() + .await + .context("Failed to send streaming request to Kimi")?; + + let status = response.status(); + + if !status.is_success() { + let text = response.text().await.unwrap_or_default(); + + if let Ok(error) = serde_json::from_str::(&text) { + bail!( + "Kimi API error: {} ({})", + error.error.message, + error.error.error_type + ); + } + + bail!("Kimi API error: {} - {}", status, text); + } + + let mut content_buffer = String::new(); + let mut has_reasoning = false; + let mut has_content = false; + let mut stream_ended = false; + + let thinking_state = self.thinking_state.as_ref(); + + let mut byte_stream = response.bytes_stream(); + let mut line_buffer = String::new(); + + use futures_util::StreamExt; + + while let Some(chunk) = byte_stream.next().await { + let chunk = chunk.context("Failed to read streaming response chunk")?; + let chunk_str = + String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?; + + line_buffer.push_str(&chunk_str); + + while let Some(line_end) = line_buffer.find('\n') { + let line = line_buffer[..line_end].trim().to_string(); + line_buffer = line_buffer[line_end + 1..].to_string(); + + if line.is_empty() { + continue; + } + + if line == "data: [DONE]" { + stream_ended = true; + break; + } + + if let Some(json_str) = line.strip_prefix("data: ") { + match serde_json::from_str::(json_str) { + Ok(chunk) => { + for choice in &chunk.choices { + if let Some(ref reasoning) = choice.delta.reasoning_content + && !reasoning.is_empty() { + if !has_reasoning { + has_reasoning = true; + if let Some(state) = thinking_state { + state.start_thinking(); + } + } + continue; + } + + if let Some(ref content) = choice.delta.content + && !content.is_empty() { + if has_reasoning && !has_content + && let Some(state) = thinking_state { + state.end_thinking(); + } + has_content = true; + content_buffer.push_str(content); + } + + if let Some(ref reason) = choice.finish_reason + && reason == "stop" { + stream_ended = true; + } + } + } + Err(_) => { + // 忽略无法解析的行 + } + } + } + } + + if stream_ended { + break; + } + } + + // 确保思考状态已结束 + if let Some(state) = thinking_state { + state.end_thinking(); + } + + let result = content_buffer.trim().to_string(); + + if result.is_empty() { + if has_reasoning && !has_content { + bail!( + "Kimi returned reasoning content but no final answer. \ + The model may have entered an incomplete thinking state. \ + Please try again or disable thinking mode." + ); + } + bail!( + "No response from Kimi. \ + If thinking mode is enabled, try disabling it or ensure the model supports it." + ); + } + + Ok(result) + } +} + +/// 可用 Kimi 模型列表 +pub const KIMI_MODELS: &[&str] = &[ + // K2 系列(推荐) + "kimi-k2.6", + "kimi-k2.5", + "kimi-k2-thinking", + "kimi-k2-thinking-turbo", + "kimi-k2-instruct", + "kimi-k2-instruct-0905", + // 兼容旧版模型 ID + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", +]; + +pub fn is_valid_model(model: &str) -> bool { + KIMI_MODELS.contains(&model) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_validation_k2() { + assert!(is_valid_model("kimi-k2.6")); + assert!(is_valid_model("kimi-k2.5")); + assert!(is_valid_model("kimi-k2-thinking")); + assert!(is_valid_model("kimi-k2-thinking-turbo")); + assert!(is_valid_model("moonshot-v1-8k")); + assert!(is_valid_model("moonshot-v1-32k")); + assert!(is_valid_model("moonshot-v1-128k")); + assert!(!is_valid_model("invalid-model")); + assert!(!is_valid_model("kimi-k1.5")); + } + + #[test] + fn test_client_builder_defaults() { + let client = KimiClient::new("test-key", "kimi-k2.6").unwrap(); + assert!(!client.thinking_enabled); + assert_eq!(client.max_tokens, 500); + assert_eq!(client.temperature, 1.0); + assert!(client.thinking_state.is_none()); + } + + #[test] + fn test_client_builder_with_thinking() { + let client = KimiClient::new("test-key", "kimi-k2.6") + .unwrap() + .with_thinking(true) + .with_max_tokens(1000) + .with_temperature(0.5); + + assert!(client.thinking_enabled); + assert_eq!(client.max_tokens, 1000); + assert_eq!(client.temperature, 0.5); + } + + #[test] + fn test_thinking_config_serialization() { + let config = ThinkingConfig { + thinking_type: "enabled".to_string(), + }; + let json = serde_json::to_string(&config).unwrap(); + assert_eq!(json, r#"{"type":"enabled"}"#); + } + + #[test] + fn test_client_new_defaults() { + let client = KimiClient::new("test-key", "kimi-k2.6").unwrap(); + assert_eq!(client.name(), "kimi"); + assert!(!client.thinking_enabled); + } + + #[test] + fn test_message_serialization() { + let msg = Message { + role: "user".to_string(), + content: "Hello".to_string(), + reasoning_content: None, + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(!json.contains("reasoning_content")); + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 810e4bc..bafbadd 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -9,6 +9,7 @@ pub mod anthropic; pub mod kimi; pub mod deepseek; pub mod openrouter; +pub mod thinking; pub use ollama::OllamaClient; pub use openai::OpenAiClient; @@ -61,48 +62,114 @@ impl Default for LlmClientConfig { impl LlmClient { /// Create LLM client from configuration manager pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result { + Self::from_config_with_think(manager, manager.config().llm.thinking_enabled).await + } + + /// Create LLM client from configuration with explicit thinking override + pub async fn from_config_with_think( + manager: &crate::config::manager::ConfigManager, + thinking_enabled: bool, + ) -> Result { let config = manager.config(); let client_config = LlmClientConfig { max_tokens: config.llm.max_tokens, temperature: config.llm.temperature, timeout: Duration::from_secs(config.llm.timeout), - thinking_enabled: config.llm.thinking_enabled, + thinking_enabled, }; let provider = config.llm.provider.as_str(); let model = config.llm.model.as_str(); let base_url = manager.llm_base_url(); let api_key = manager.get_api_key(); - let thinking_enabled = config.llm.thinking_enabled; let provider: Box = match provider { "ollama" => { - Box::new(OllamaClient::new(&base_url, model)) + Box::new(OllamaClient::new(&base_url, model) + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature)) } "openai" => { let key = api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?; - Box::new(OpenAiClient::new(&base_url, key, model)?) + let thinking_state = if thinking_enabled { + Some(thinking::create_console_thinking_state()) + } else { + None + }; + let mut client = OpenAiClient::new(&base_url, key, model)? + .with_thinking(thinking_enabled) + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature) + .with_timeout(client_config.timeout)?; + if let Some(state) = thinking_state { + client = client.with_thinking_state(state); + } + Box::new(client) } "anthropic" => { let key = api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?; - Box::new(AnthropicClient::new(key, model)?) + let thinking_state = if thinking_enabled { + Some(thinking::create_console_thinking_state()) + } else { + None + }; + let budget = config.llm.thinking_budget_tokens.unwrap_or(1024); + let mut client = AnthropicClient::new(key, model)? + .with_thinking(thinking_enabled) + .with_thinking_budget_tokens(budget) + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature) + .with_timeout(client_config.timeout)?; + if let Some(state) = thinking_state { + client = client.with_thinking_state(state); + } + Box::new(client) } "kimi" => { let key = api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?; - Box::new(KimiClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled)) + let thinking_state = if thinking_enabled { + Some(thinking::create_console_thinking_state()) + } else { + None + }; + let mut client = KimiClient::with_base_url(key, model, &base_url)? + .with_thinking(thinking_enabled) + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature) + .with_timeout(client_config.timeout)?; + if let Some(state) = thinking_state { + client = client.with_thinking_state(state); + } + Box::new(client) } "deepseek" => { let key = api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?; - Box::new(DeepSeekClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled)) + let thinking_state = if thinking_enabled { + Some(thinking::create_console_thinking_state()) + } else { + None + }; + let mut client = DeepSeekClient::with_base_url(key, model, &base_url)? + .with_thinking(thinking_enabled) + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature) + .with_timeout(client_config.timeout)?; + if let Some(state) = thinking_state { + client = client.with_thinking_state(state); + } + Box::new(client) } "openrouter" => { let key = api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?; - Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?) + Box::new(OpenRouterClient::with_base_url(key, model, &base_url)? + .with_max_tokens(client_config.max_tokens) + .with_temperature(client_config.temperature) + .with_timeout(client_config.timeout)?) } _ => bail!("Unknown LLM provider: {}", provider), }; diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs index 372b18f..33540f8 100644 --- a/src/llm/ollama.rs +++ b/src/llm/ollama.rs @@ -9,6 +9,9 @@ pub struct OllamaClient { base_url: String, model: String, client: reqwest::Client, + max_tokens: u32, + temperature: f32, + top_p: Option, } #[derive(Debug, Serialize)] @@ -49,11 +52,14 @@ impl OllamaClient { pub fn new(base_url: &str, model: &str) -> Self { let client = create_http_client(Duration::from_secs(120)) .expect("Failed to create HTTP client"); - + Self { base_url: base_url.trim_end_matches('/').to_string(), model: model.to_string(), client, + max_tokens: 500, + temperature: 0.7, + top_p: None, } } @@ -64,6 +70,21 @@ impl OllamaClient { self } + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + /// List available models pub async fn list_models(&self) -> Result> { let url = format!("{}/api/tags", self.base_url); @@ -143,8 +164,8 @@ impl LlmProvider for OllamaClient { system, stream: false, options: GenerationOptions { - temperature: Some(0.7), - num_predict: Some(500), + temperature: Some(self.temperature), + num_predict: Some(self.max_tokens), }, }; diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 3f85c40..8785ab3 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -1,15 +1,23 @@ +use super::thinking::ThinkingStateManager; use super::{create_http_client, LlmProvider}; use anyhow::{bail, Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use std::time::Duration; -/// OpenAI API client +/// OpenAI API client with o-series reasoning support pub struct OpenAiClient { base_url: String, api_key: String, model: String, client: reqwest::Client, + thinking_enabled: bool, + reasoning_effort: Option, + max_tokens: u32, + temperature: f32, + top_p: Option, + thinking_state: Option>, } #[derive(Debug, Serialize)] @@ -20,10 +28,14 @@ struct ChatCompletionRequest { max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_effort: Option, stream: bool, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] struct Message { role: String, content: String, @@ -39,6 +51,28 @@ struct Choice { message: Message, } +// --- Streaming response structures --- + +#[derive(Debug, Deserialize)] +struct StreamChunk { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize, Default)] +struct StreamDelta { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, +} + #[derive(Debug, Deserialize)] struct ErrorResponse { error: ApiError, @@ -55,57 +89,91 @@ impl OpenAiClient { /// Create new OpenAI client pub fn new(base_url: &str, api_key: &str, model: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; - + Ok(Self { base_url: base_url.trim_end_matches('/').to_string(), api_key: api_key.to_string(), model: model.to_string(), client, + thinking_enabled: false, + reasoning_effort: None, + max_tokens: 500, + temperature: 0.7, + top_p: None, + thinking_state: None, }) } - /// Set timeout pub fn with_timeout(mut self, timeout: Duration) -> Result { self.client = create_http_client(timeout)?; Ok(self) } - /// List available models + pub fn with_thinking(mut self, enabled: bool) -> Self { + self.thinking_enabled = enabled; + self + } + + pub fn with_reasoning_effort(mut self, effort: Option) -> Self { + self.reasoning_effort = effort; + self + } + + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn with_thinking_state(mut self, state: Arc) -> Self { + self.thinking_state = Some(state); + self + } + pub async fn list_models(&self) -> Result> { let url = format!("{}/models", self.base_url); - - let response = self.client + + let response = self + .client .get(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .send() .await .context("Failed to list OpenAI models")?; - + if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); bail!("OpenAI API error: {} - {}", status, text); } - + #[derive(Deserialize)] struct ModelsResponse { data: Vec, } - + #[derive(Deserialize)] struct Model { id: String, } - + let result: ModelsResponse = response .json() .await .context("Failed to parse OpenAI response")?; - + Ok(result.data.into_iter().map(|m| m.id).collect()) } - /// Validate API key pub async fn validate_key(&self) -> Result { match self.list_models().await { Ok(_) => Ok(true), @@ -124,32 +192,30 @@ impl OpenAiClient { #[async_trait] impl LlmProvider for OpenAiClient { async fn generate(&self, prompt: &str) -> Result { - let messages = vec![ - Message { - role: "user".to_string(), - content: prompt.to_string(), - }, - ]; - - self.chat_completion(messages).await + let messages = vec![Message { + role: "user".to_string(), + content: prompt.to_string(), + }]; + + self.chat_completion_with_retry(messages).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let mut messages = vec![]; - + if !system.is_empty() { messages.push(Message { role: "system".to_string(), content: system.to_string(), }); } - + messages.push(Message { role: "user".to_string(), content: user.to_string(), }); - - self.chat_completion(messages).await + + self.chat_completion_with_retry(messages).await } async fn is_available(&self) -> bool { @@ -162,18 +228,59 @@ impl LlmProvider for OpenAiClient { } impl OpenAiClient { + async fn chat_completion_with_retry(&self, messages: Vec) -> Result { + let mut last_error = None; + + for attempt in 1..=3 { + match self.chat_completion(messages.clone()).await { + Ok(result) => return Ok(result), + Err(e) => { + let err_msg = e.to_string(); + let is_retryable = err_msg.contains("timeout") + || err_msg.contains("connection") + || err_msg.contains("temporary") + || err_msg.contains("5") + && (err_msg.contains("500") + || err_msg.contains("502") + || err_msg.contains("503") + || err_msg.contains("504")); + + if !is_retryable || attempt == 3 { + last_error = Some(e); + break; + } + + tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await; + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries"))) + } + async fn chat_completion(&self, messages: Vec) -> Result { + if self.thinking_enabled { + self.streaming_chat_completion(messages).await + } else { + self.non_streaming_chat_completion(messages).await + } + } + + async fn non_streaming_chat_completion(&self, messages: Vec) -> Result { let url = format!("{}/chat/completions", self.base_url); - + let request = ChatCompletionRequest { model: self.model.clone(), messages, - max_tokens: Some(500), - temperature: Some(0.7), + max_tokens: Some(self.max_tokens), + temperature: Some(self.temperature), + top_p: self.top_p, + reasoning_effort: None, stream: false, }; - - let response = self.client + + let response = self + .client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") @@ -181,31 +288,165 @@ impl OpenAiClient { .send() .await .context("Failed to send request to OpenAI")?; - + let status = response.status(); - + if !status.is_success() { let text = response.text().await.unwrap_or_default(); - - // Try to parse error + if let Ok(error) = serde_json::from_str::(&text) { - bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type); + bail!( + "OpenAI API error: {} ({})", + error.error.message, + error.error.error_type + ); } - + bail!("OpenAI API error: {} - {}", status, text); } - + let result: ChatCompletionResponse = response .json() .await .context("Failed to parse OpenAI response")?; - - result.choices + + result + .choices .into_iter() .next() .map(|c| c.message.content.trim().to_string()) + .filter(|s| !s.is_empty()) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) } + + /// Streaming request for reasoning mode, filters reasoning_content from output + async fn streaming_chat_completion(&self, messages: Vec) -> Result { + let url = format!("{}/chat/completions", self.base_url); + + // For reasoning/thinking mode, omit temperature and top_p + let request = ChatCompletionRequest { + model: self.model.clone(), + messages, + max_tokens: Some(self.max_tokens), + temperature: None, + top_p: None, + reasoning_effort: self.reasoning_effort.clone(), + stream: true, + }; + + let response = self + .client + .post(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(&request) + .send() + .await + .context("Failed to send streaming request to OpenAI")?; + + let status = response.status(); + + if !status.is_success() { + let text = response.text().await.unwrap_or_default(); + + if let Ok(error) = serde_json::from_str::(&text) { + bail!( + "OpenAI API error: {} ({})", + error.error.message, + error.error.error_type + ); + } + + bail!("OpenAI API error: {} - {}", status, text); + } + + let mut content_buffer = String::new(); + let mut has_reasoning = false; + let mut has_content = false; + + let thinking_state = self.thinking_state.as_ref(); + + let mut byte_stream = response.bytes_stream(); + let mut line_buffer = String::new(); + + use futures_util::StreamExt; + + while let Some(chunk) = byte_stream.next().await { + let chunk = chunk.context("Failed to read streaming response chunk")?; + let chunk_str = + String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?; + + line_buffer.push_str(&chunk_str); + + while let Some(line_end) = line_buffer.find('\n') { + let line = line_buffer[..line_end].trim().to_string(); + line_buffer = line_buffer[line_end + 1..].to_string(); + + if line.is_empty() { + continue; + } + + if line == "data: [DONE]" { + break; + } + + if let Some(json_str) = line.strip_prefix("data: ") { + if let Ok(chunk) = serde_json::from_str::(json_str) { + for choice in &chunk.choices { + // Handle reasoning_content (o-series) + if let Some(ref reasoning) = choice.delta.reasoning_content + && !reasoning.is_empty() + { + if !has_reasoning { + has_reasoning = true; + if let Some(state) = thinking_state { + state.start_thinking(); + } + } + continue; + } + + // Handle content + if let Some(ref content) = choice.delta.content + && !content.is_empty() + { + if has_reasoning && !has_content + && let Some(state) = thinking_state + { + state.end_thinking(); + } + has_content = true; + content_buffer.push_str(content); + } + } + } + } + } + } + + if let Some(state) = thinking_state { + state.end_thinking(); + } + + let result = content_buffer.trim().to_string(); + + if result.is_empty() { + if has_reasoning && !has_content { + bail!( + "OpenAI returned reasoning content but no final answer. \ + The model may have entered an incomplete reasoning state. \ + Please try again or disable thinking mode." + ); + } + bail!( + "No response from OpenAI. \ + If thinking mode is enabled, try disabling it or ensure the model supports reasoning." + ); + } + + Ok(result) + } } /// Azure OpenAI client (extends OpenAI with Azure-specific config) @@ -215,10 +456,15 @@ pub struct AzureOpenAiClient { deployment: String, api_version: String, client: reqwest::Client, + thinking_enabled: bool, + reasoning_effort: Option, + max_tokens: u32, + temperature: f32, + top_p: Option, + thinking_state: Option>, } impl AzureOpenAiClient { - /// Create new Azure OpenAI client pub fn new( endpoint: &str, api_key: &str, @@ -226,13 +472,19 @@ impl AzureOpenAiClient { api_version: &str, ) -> Result { let client = create_http_client(Duration::from_secs(60))?; - + Ok(Self { endpoint: endpoint.trim_end_matches('/').to_string(), api_key: api_key.to_string(), deployment: deployment.to_string(), api_version: api_version.to_string(), client, + thinking_enabled: false, + reasoning_effort: None, + max_tokens: 500, + temperature: 0.7, + top_p: None, + thinking_state: None, }) } @@ -241,16 +493,19 @@ impl AzureOpenAiClient { "{}/openai/deployments/{}/chat/completions?api-version={}", self.endpoint, self.deployment, self.api_version ); - + let request = ChatCompletionRequest { model: self.deployment.clone(), messages, - max_tokens: Some(500), - temperature: Some(0.7), + max_tokens: Some(self.max_tokens), + temperature: Some(self.temperature), + top_p: self.top_p, + reasoning_effort: self.reasoning_effort.clone(), stream: false, }; - - let response = self.client + + let response = self + .client .post(&url) .header("api-key", &self.api_key) .header("Content-Type", "application/json") @@ -258,22 +513,24 @@ impl AzureOpenAiClient { .send() .await .context("Failed to send request to Azure OpenAI")?; - + if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); bail!("Azure OpenAI API error: {} - {}", status, text); } - + let result: ChatCompletionResponse = response .json() .await .context("Failed to parse Azure OpenAI response")?; - - result.choices + + result + .choices .into_iter() .next() .map(|c| c.message.content.trim().to_string()) + .filter(|s| !s.is_empty()) .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI")) } } @@ -281,41 +538,38 @@ impl AzureOpenAiClient { #[async_trait] impl LlmProvider for AzureOpenAiClient { async fn generate(&self, prompt: &str) -> Result { - let messages = vec![ - Message { - role: "user".to_string(), - content: prompt.to_string(), - }, - ]; - + let messages = vec![Message { + role: "user".to_string(), + content: prompt.to_string(), + }]; + self.chat_completion(messages).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let mut messages = vec![]; - + if !system.is_empty() { messages.push(Message { role: "system".to_string(), content: system.to_string(), }); } - + messages.push(Message { role: "user".to_string(), content: user.to_string(), }); - + self.chat_completion(messages).await } async fn is_available(&self) -> bool { - // Simple check - try to make a minimal request let url = format!( "{}/openai/deployments/{}/chat/completions?api-version={}", self.endpoint, self.deployment, self.api_version ); - + let request = ChatCompletionRequest { model: self.deployment.clone(), messages: vec![Message { @@ -324,10 +578,13 @@ impl LlmProvider for AzureOpenAiClient { }], max_tokens: Some(5), temperature: Some(0.0), + top_p: None, + reasoning_effort: None, stream: false, }; - - match self.client + + match self + .client .post(&url) .header("api-key", &self.api_key) .json(&request) @@ -343,3 +600,59 @@ impl LlmProvider for AzureOpenAiClient { "azure-openai" } } + +/// Available OpenAI models (including o-series with reasoning) +pub const OPENAI_MODELS: &[&str] = &[ + "o4-mini", + "o3", + "o3-mini", + "o1", + "o1-mini", + "o1-pro", + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4", + "gpt-3.5-turbo", +]; + +pub fn is_valid_model(model: &str) -> bool { + OPENAI_MODELS.contains(&model) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_validation_o_series() { + assert!(is_valid_model("o4-mini")); + assert!(is_valid_model("o3")); + assert!(is_valid_model("o1")); + assert!(is_valid_model("gpt-4o")); + assert!(is_valid_model("gpt-3.5-turbo")); + assert!(!is_valid_model("invalid-model")); + } + + #[test] + fn test_stream_delta_reasoning_parsing() { + let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#; + let delta: StreamDelta = serde_json::from_str(json).unwrap(); + assert!(delta.content.is_none()); + assert_eq!( + delta.reasoning_content, + Some("Let me think...".to_string()) + ); + } + + #[test] + fn test_stream_delta_content_parsing() { + let json = r#"{"content":"Hello","reasoning_content":null}"#; + let delta: StreamDelta = serde_json::from_str(json).unwrap(); + assert_eq!(delta.content, Some("Hello".to_string())); + assert!(delta.reasoning_content.is_none()); + } +} diff --git a/src/llm/openrouter.rs b/src/llm/openrouter.rs index 8b0f0f6..d5bc665 100644 --- a/src/llm/openrouter.rs +++ b/src/llm/openrouter.rs @@ -10,6 +10,9 @@ pub struct OpenRouterClient { api_key: String, model: String, client: reqwest::Client, + max_tokens: u32, + temperature: f32, + top_p: Option, } #[derive(Debug, Serialize)] @@ -55,24 +58,30 @@ impl OpenRouterClient { /// Create new OpenRouter client pub fn new(api_key: &str, model: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; - + Ok(Self { base_url: "https://openrouter.ai/api/v1".to_string(), api_key: api_key.to_string(), model: model.to_string(), client, + max_tokens: 500, + temperature: 0.7, + top_p: None, }) } /// Create with custom base URL pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; - + Ok(Self { base_url: base_url.trim_end_matches('/').to_string(), api_key: api_key.to_string(), model: model.to_string(), client, + max_tokens: 500, + temperature: 0.7, + top_p: None, }) } @@ -82,6 +91,21 @@ impl OpenRouterClient { Ok(self) } + pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = max_tokens; + self + } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.temperature = temperature; + self + } + + pub fn with_top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + /// List available models pub async fn list_models(&self) -> Result> { let url = format!("{}/models", self.base_url); @@ -182,8 +206,8 @@ impl OpenRouterClient { let request = ChatCompletionRequest { model: self.model.clone(), messages, - max_tokens: Some(500), - temperature: Some(0.7), + max_tokens: Some(self.max_tokens), + temperature: Some(self.temperature), stream: false, }; diff --git a/src/llm/thinking.rs b/src/llm/thinking.rs new file mode 100644 index 0000000..a0c43f5 --- /dev/null +++ b/src/llm/thinking.rs @@ -0,0 +1,152 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏 +pub struct ThinkingStateManager { + is_thinking: AtomicBool, + on_start: Option>, + on_end: Option>, +} + +impl ThinkingStateManager { + pub fn new() -> Self { + Self { + is_thinking: AtomicBool::new(false), + on_start: None, + on_end: None, + } + } + + /// 设置思考开始回调 + pub fn on_thinking_start(mut self, callback: F) -> Self { + self.on_start = Some(Box::new(callback)); + self + } + + /// 设置思考结束回调 + pub fn on_thinking_end(mut self, callback: F) -> Self { + self.on_end = Some(Box::new(callback)); + self + } + + /// 开始思考状态 + pub fn start_thinking(&self) { + if !self.is_thinking.load(Ordering::SeqCst) { + self.is_thinking.store(true, Ordering::SeqCst); + if let Some(ref cb) = self.on_start { + cb(); + } + } + } + + /// 结束思考状态 + pub fn end_thinking(&self) { + if self.is_thinking.load(Ordering::SeqCst) { + self.is_thinking.store(false, Ordering::SeqCst); + if let Some(ref cb) = self.on_end { + cb(); + } + } + } + + /// 当前是否处于思考状态 + pub fn is_thinking(&self) -> bool { + self.is_thinking.load(Ordering::SeqCst) + } +} + +impl Default for ThinkingStateManager { + fn default() -> Self { + Self::new() + } +} + +/// 线程安全的思考状态管理器引用 +pub type SharedThinkingState = Arc; + +/// 创建带有默认控制台输出的思考状态管理器 +/// 在思考开始时打印 "thinking...",在思考结束时清除该标识 +pub fn create_console_thinking_state() -> SharedThinkingState { + Arc::new( + ThinkingStateManager::new() + .on_thinking_start(|| { + eprint!("\rthinking..."); + }) + .on_thinking_end(|| { + eprint!("\r \r"); + }), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + #[test] + fn test_thinking_state_transitions() { + let manager = ThinkingStateManager::new(); + assert!(!manager.is_thinking()); + + manager.start_thinking(); + assert!(manager.is_thinking()); + + manager.end_thinking(); + assert!(!manager.is_thinking()); + } + + #[test] + fn test_thinking_idempotent_start() { + let manager = ThinkingStateManager::new(); + manager.start_thinking(); + manager.start_thinking(); // 重复调用不应触发回调两次 + assert!(manager.is_thinking()); + } + + #[test] + fn test_thinking_idempotent_end() { + let manager = ThinkingStateManager::new(); + manager.end_thinking(); // 未开始时结束不应触发问题 + assert!(!manager.is_thinking()); + } + + #[test] + fn test_thinking_callbacks() { + let events: Arc>> = Arc::new(Mutex::new(Vec::new())); + let events_clone = events.clone(); + + let manager = ThinkingStateManager::new() + .on_thinking_start(move || { + events_clone.lock().unwrap().push("start".to_string()); + }); + + let events_clone2 = events.clone(); + let manager = manager.on_thinking_end(move || { + events_clone2.lock().unwrap().push("end".to_string()); + }); + + manager.start_thinking(); + manager.end_thinking(); + + let recorded = events.lock().unwrap(); + assert_eq!(recorded.len(), 2); + assert_eq!(recorded[0], "start"); + assert_eq!(recorded[1], "end"); + } + + #[test] + fn test_create_console_thinking_state() { + let state = create_console_thinking_state(); + assert!(!state.is_thinking()); + state.start_thinking(); + assert!(state.is_thinking()); + state.end_thinking(); + assert!(!state.is_thinking()); + } + + #[test] + fn test_default() { + let manager = ThinkingStateManager::default(); + assert!(!manager.is_thinking()); + } +} diff --git a/src/utils/keyring.rs b/src/utils/keyring.rs index fb68c81..d7c849a 100644 --- a/src/utils/keyring.rs +++ b/src/utils/keyring.rs @@ -85,11 +85,10 @@ impl KeyringManager { } pub fn get_api_key(&self, provider: &str) -> Result> { - if let Ok(key) = env::var(ENV_API_KEY) { - if !key.is_empty() { + if let Ok(key) = env::var(ENV_API_KEY) + && !key.is_empty() { return Ok(Some(key)); } - } if !self.is_available() { return Ok(None); @@ -251,8 +250,8 @@ pub fn get_default_model(provider: &str) -> &'static str { match provider { "openai" => "gpt-4", "anthropic" => "claude-3-sonnet-20240229", - "kimi" => "moonshot-v1-8k", - "deepseek" => "deepseek-chat", + "kimi" => "kimi-k2.6", + "deepseek" => "deepseek-v4-flash", "openrouter" => "openai/gpt-3.5-turbo", "ollama" => "llama2", _ => "", diff --git a/tests/config_export_import_tests.rs b/tests/config_export_import_tests.rs index 3d04420..3d35ac4 100644 --- a/tests/config_export_import_tests.rs +++ b/tests/config_export_import_tests.rs @@ -1,38 +1,11 @@ -use assert_cmd::Command; +use assert_cmd::cargo::cargo_bin_cmd; use predicates::prelude::*; use std::fs; use std::path::PathBuf; use tempfile::TempDir; -fn create_git_repo(dir: &PathBuf) -> std::process::Output { - std::process::Command::new("git") - .args(&["init"]) - .current_dir(dir) - .output() - .expect("Failed to init git repo") -} - -fn configure_git_user(dir: &PathBuf) { - std::process::Command::new("git") - .args(&["config", "user.name", "Test User"]) - .current_dir(dir) - .output() - .expect("Failed to configure git user name"); - - std::process::Command::new("git") - .args(&["config", "user.email", "test@example.com"]) - .current_dir(dir) - .output() - .expect("Failed to configure git user email"); -} - -fn setup_git_repo(dir: &PathBuf) { - create_git_repo(dir); - configure_git_user(dir); -} - fn init_quicommit(config_path: &PathBuf) { - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); } @@ -46,7 +19,7 @@ mod config_export { let config_path = temp_dir.path().join("config.toml"); init_quicommit(&config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "export", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -62,7 +35,7 @@ mod config_export { let export_path = temp_dir.path().join("exported.toml"); init_quicommit(&config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path.to_str().unwrap(), @@ -88,7 +61,7 @@ mod config_export { let export_path = temp_dir.path().join("encrypted.toml"); init_quicommit(&config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path.to_str().unwrap(), @@ -164,7 +137,7 @@ keep_changelog_types_english = true "#; fs::write(&import_path, plain_config).unwrap(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "import", "--config", config_path.to_str().unwrap(), @@ -175,7 +148,7 @@ keep_changelog_types_english = true .success() .stdout(predicate::str::contains("Configuration imported")); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "get", "llm.provider", "--config", config_path.to_str().unwrap()]); cmd.assert() .success() @@ -191,14 +164,14 @@ keep_changelog_types_english = true init_quicommit(&config_path1); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "set", "llm.provider", "anthropic", "--config", config_path1.to_str().unwrap() ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path1.to_str().unwrap(), @@ -207,7 +180,7 @@ keep_changelog_types_english = true ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "import", "--config", config_path2.to_str().unwrap(), @@ -218,7 +191,7 @@ keep_changelog_types_english = true .success() .stdout(predicate::str::contains("Configuration imported")); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]); cmd.assert() .success() @@ -233,7 +206,7 @@ keep_changelog_types_english = true init_quicommit(&config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path.to_str().unwrap(), @@ -242,7 +215,7 @@ keep_changelog_types_english = true ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "import", "--config", config_path.to_str().unwrap(), @@ -267,14 +240,14 @@ mod config_export_import_roundtrip { init_quicommit(&config_path1); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "set", "llm.model", "gpt-4-turbo", "--config", config_path1.to_str().unwrap() ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path1.to_str().unwrap(), @@ -283,7 +256,7 @@ mod config_export_import_roundtrip { ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "import", "--config", config_path2.to_str().unwrap(), @@ -291,7 +264,7 @@ mod config_export_import_roundtrip { ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]); cmd.assert() .success() @@ -308,21 +281,21 @@ mod config_export_import_roundtrip { init_quicommit(&config_path1); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "set", "llm.provider", "deepseek", "--config", config_path1.to_str().unwrap() ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "set", "llm.model", "deepseek-chat", "--config", config_path1.to_str().unwrap() ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "export", "--config", config_path1.to_str().unwrap(), @@ -335,7 +308,7 @@ mod config_export_import_roundtrip { assert!(exported_content.starts_with("ENCRYPTED:")); assert!(!exported_content.contains("deepseek")); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&[ "config", "import", "--config", config_path2.to_str().unwrap(), @@ -344,13 +317,13 @@ mod config_export_import_roundtrip { ]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]); cmd.assert() .success() .stdout(predicate::str::contains("deepseek")); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]); cmd.assert() .success() diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index d3016c5..c13713c 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,4 +1,4 @@ -use assert_cmd::Command; +use assert_cmd::cargo::cargo_bin_cmd; use predicates::prelude::*; use std::fs; use std::path::PathBuf; @@ -59,7 +59,7 @@ fn setup_test_repo_with_file(dir: &PathBuf, file_name: &str, file_content: &str) } fn init_quicommit(dir: &PathBuf, config_path: &PathBuf) { - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(dir); cmd.assert().success(); @@ -70,7 +70,7 @@ mod cli_basic { #[test] fn test_help() { - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.arg("--help"); cmd.assert() .success() @@ -83,7 +83,7 @@ mod cli_basic { #[test] fn test_version() { - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.arg("--version"); cmd.assert() .success() @@ -92,7 +92,7 @@ mod cli_basic { #[test] fn test_no_args_shows_help() { - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.assert() .failure() .stderr(predicate::str::contains("Usage:")); @@ -106,7 +106,7 @@ mod cli_basic { create_git_repo(&repo_path); configure_git_user(&repo_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["-vv", "init", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -122,7 +122,7 @@ mod init_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -135,7 +135,7 @@ mod init_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); @@ -152,7 +152,7 @@ mod init_command { let config_path = repo_path.join("test_config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -164,11 +164,11 @@ mod init_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--reset", "--config", config_path.to_str().unwrap()]); cmd.assert() .success() @@ -184,7 +184,7 @@ mod profile_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -197,11 +197,11 @@ mod profile_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -218,11 +218,11 @@ mod config_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -235,11 +235,11 @@ mod config_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -256,11 +256,11 @@ mod commit_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(temp_dir.path()); @@ -278,7 +278,7 @@ mod commit_command { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--manual", "-m", "test: empty", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -296,7 +296,7 @@ mod commit_command { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--manual", "-m", "test: add test file", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -314,7 +314,7 @@ mod commit_command { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--date", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -322,6 +322,22 @@ mod commit_command { .success() .stdout(predicate::str::contains("Dry run")); } + + #[test] + fn test_commit_with_think_flag() { + let temp_dir = TempDir::new().unwrap(); + let repo_path = temp_dir.path().to_path_buf(); + setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!"); + + let config_path = repo_path.join("config.toml"); + init_quicommit(&repo_path, &config_path); + + let mut cmd = cargo_bin_cmd!("quicommit"); + cmd.args(&["commit", "--think", "--manual", "-m", "test: think flag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) + .current_dir(&repo_path); + + cmd.assert().success(); + } } mod tag_command { @@ -332,11 +348,11 @@ mod tag_command { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["tag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(temp_dir.path()); @@ -358,7 +374,7 @@ mod tag_command { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["tag", "--name", "v0.1.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -366,6 +382,26 @@ mod tag_command { .success() .stdout(predicate::str::contains("v0.1.0")); } + + #[test] + fn test_tag_with_think_flag() { + let temp_dir = TempDir::new().unwrap(); + let repo_path = temp_dir.path().to_path_buf(); + setup_git_repo(&repo_path); + + create_test_file(&repo_path, "test.txt", "content"); + stage_file(&repo_path, "test.txt"); + create_commit(&repo_path, "feat: initial commit"); + + let config_path = repo_path.join("config.toml"); + init_quicommit(&repo_path, &config_path); + + let mut cmd = cargo_bin_cmd!("quicommit"); + cmd.args(&["tag", "--think", "--name", "v0.2.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) + .current_dir(&repo_path); + + cmd.assert().success(); + } } mod changelog_command { @@ -382,7 +418,7 @@ mod changelog_command { init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["changelog", "--init", "--output", changelog_path.to_str().unwrap(), "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -404,7 +440,7 @@ mod changelog_command { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["changelog", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -421,7 +457,7 @@ mod cross_platform { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("subdir").join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); @@ -435,7 +471,7 @@ mod cross_platform { fs::create_dir_all(&space_dir).unwrap(); let config_path = space_dir.join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); @@ -449,7 +485,7 @@ mod cross_platform { fs::create_dir_all(&unicode_dir).unwrap(); let config_path = unicode_dir.join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); @@ -523,7 +559,7 @@ mod validators { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--manual", "-m", "invalid commit message without type", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -541,7 +577,7 @@ mod validators { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--manual", "-m", "feat: add new feature", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -563,7 +599,7 @@ mod subcommands { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["c", "--manual", "-m", "fix: test", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -577,7 +613,7 @@ mod subcommands { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -590,11 +626,11 @@ mod subcommands { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path().join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]); cmd.assert() @@ -611,7 +647,7 @@ mod edge_cases { let temp_dir = TempDir::new().unwrap(); let non_existent_config = temp_dir.path().join("non_existent_config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["config", "show", "--config", non_existent_config.to_str().unwrap()]); cmd.assert() @@ -627,11 +663,11 @@ mod edge_cases { let repo_path = temp_dir.path().to_path_buf(); let config_path = repo_path.join("config.toml"); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]); cmd.assert().success(); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path); @@ -649,7 +685,7 @@ mod edge_cases { let config_path = repo_path.join("config.toml"); init_quicommit(&repo_path, &config_path); - let mut cmd = Command::cargo_bin("quicommit").unwrap(); + let mut cmd = cargo_bin_cmd!("quicommit"); cmd.args(&["commit", "--manual", "-m", "", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()]) .current_dir(&repo_path);