LLM支持优化
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -21,3 +21,5 @@ test_output/
|
|||||||
|
|
||||||
# Config (for development)
|
# Config (for development)
|
||||||
config.toml
|
config.toml
|
||||||
|
.claude/
|
||||||
|
CLAUDE.md
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "quicommit"
|
name = "quicommit"
|
||||||
version = "0.2.1"
|
version = "0.2.3"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
authors = ["Sidney Zhang <zly@lyzhang.me>"]
|
authors = ["Sidney Zhang <zly@lyzhang.me>"]
|
||||||
description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)"
|
description = "A powerful Git assistant tool with AI-powered commit/tag/changelog generation(alpha version)"
|
||||||
@@ -33,7 +33,7 @@ git2 = "0.20.3"
|
|||||||
which = "6.0"
|
which = "6.0"
|
||||||
|
|
||||||
# HTTP client for LLM APIs
|
# HTTP client for LLM APIs
|
||||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
reqwest = { version = "0.12", features = ["json", "rustls-tls", "stream"], default-features = false }
|
||||||
tokio = { version = "1.35", features = ["full"] }
|
tokio = { version = "1.35", features = ["full"] }
|
||||||
|
|
||||||
# Error handling
|
# Error handling
|
||||||
@@ -57,6 +57,7 @@ sha2 = "0.10"
|
|||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
textwrap = "0.16"
|
textwrap = "0.16"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
|
futures-util = "0.3"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
atty = "0.2"
|
atty = "0.2"
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,10 @@ pub struct ChangelogCommand {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
dry_run: bool,
|
dry_run: bool,
|
||||||
|
|
||||||
|
/// Enable thinking mode for this changelog (override config)
|
||||||
|
#[arg(long)]
|
||||||
|
think: bool,
|
||||||
|
|
||||||
/// Skip interactive prompts
|
/// Skip interactive prompts
|
||||||
#[arg(short = 'y', long)]
|
#[arg(short = 'y', long)]
|
||||||
yes: bool,
|
yes: bool,
|
||||||
@@ -74,8 +78,7 @@ impl ChangelogCommand {
|
|||||||
|
|
||||||
// Initialize changelog if requested
|
// Initialize changelog if requested
|
||||||
if self.init {
|
if self.init {
|
||||||
let path = self.output.as_ref()
|
let path = self.output.clone()
|
||||||
.map(|p| p.clone())
|
|
||||||
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
|
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
|
||||||
|
|
||||||
init_changelog(&path)?;
|
init_changelog(&path)?;
|
||||||
@@ -84,8 +87,7 @@ impl ChangelogCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine output path
|
// Determine output path
|
||||||
let output_path = self.output.as_ref()
|
let output_path = self.output.clone()
|
||||||
.map(|p| p.clone())
|
|
||||||
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
|
.unwrap_or_else(|| PathBuf::from(&config.changelog.path));
|
||||||
|
|
||||||
// Determine format
|
// Determine format
|
||||||
@@ -148,7 +150,7 @@ impl ChangelogCommand {
|
|||||||
println!("{}", "─".repeat(60));
|
println!("{}", "─".repeat(60));
|
||||||
|
|
||||||
let confirm = Confirm::new()
|
let confirm = Confirm::new()
|
||||||
.with_prompt(&messages.write_to_file(&format!("{:?}", output_path)))
|
.with_prompt(messages.write_to_file(&format!("{:?}", output_path)))
|
||||||
.default(true)
|
.default(true)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
|
|
||||||
@@ -208,7 +210,7 @@ impl ChangelogCommand {
|
|||||||
|
|
||||||
println!("{}", messages.ai_generating_changelog());
|
println!("{}", messages.ai_generating_changelog());
|
||||||
|
|
||||||
let generator = ContentGenerator::new(&manager).await?;
|
let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
|
||||||
generator.generate_changelog_entry(version, commits, language).await
|
generator.generate_changelog_entry(version, commits, language).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,7 +239,8 @@ impl ChangelogCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String {
|
fn translate_changelog_categories(&self, changelog: &str, language: Language) -> String {
|
||||||
let translated = changelog
|
|
||||||
|
changelog
|
||||||
.lines()
|
.lines()
|
||||||
.map(|line| {
|
.map(|line| {
|
||||||
if line.starts_with("## ") || line.starts_with("### ") {
|
if line.starts_with("## ") || line.starts_with("### ") {
|
||||||
@@ -253,7 +256,6 @@ impl ChangelogCommand {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("\n");
|
.join("\n")
|
||||||
translated
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ pub struct CommitCommand {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
no_verify: bool,
|
no_verify: bool,
|
||||||
|
|
||||||
|
/// Enable thinking mode for this commit (override config)
|
||||||
|
#[arg(short = 't', long)]
|
||||||
|
think: bool,
|
||||||
|
|
||||||
/// Skip interactive prompts
|
/// Skip interactive prompts
|
||||||
#[arg(short = 'y', long)]
|
#[arg(short = 'y', long)]
|
||||||
yes: bool,
|
yes: bool,
|
||||||
@@ -198,7 +202,7 @@ impl CommitCommand {
|
|||||||
if let Some(commit_oid) = result {
|
if let Some(commit_oid) = result {
|
||||||
println!("{} {}", messages.commit_created().green().bold(), commit_oid.to_string()[..8].to_string().cyan());
|
println!("{} {}", messages.commit_created().green().bold(), commit_oid.to_string()[..8].to_string().cyan());
|
||||||
} else {
|
} else {
|
||||||
println!("{} {}", messages.commit_amended().green().bold(), "successfully");
|
println!("{} successfully", messages.commit_amended().green().bold());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Push after commit if requested or ask user
|
// Push after commit if requested or ask user
|
||||||
@@ -258,7 +262,7 @@ impl CommitCommand {
|
|||||||
async fn generate_commit(&self, repo: &GitRepo, format: CommitFormat, messages: &Messages) -> Result<String> {
|
async fn generate_commit(&self, repo: &GitRepo, format: CommitFormat, messages: &Messages) -> Result<String> {
|
||||||
let manager = ConfigManager::new()?;
|
let manager = ConfigManager::new()?;
|
||||||
|
|
||||||
let generator = ContentGenerator::new(&manager).await
|
let generator = ContentGenerator::new_with_think(&manager, self.think).await
|
||||||
.context("Failed to initialize LLM. Use --manual for manual commit.")?;
|
.context("Failed to initialize LLM. Use --manual for manual commit.")?;
|
||||||
|
|
||||||
println!("{}", messages.ai_analyzing());
|
println!("{}", messages.ai_analyzing());
|
||||||
|
|||||||
@@ -393,8 +393,15 @@ impl ConfigCommand {
|
|||||||
let timeout: u64 = value.parse()?;
|
let timeout: u64 = value.parse()?;
|
||||||
manager.config_mut().llm.timeout = timeout;
|
manager.config_mut().llm.timeout = timeout;
|
||||||
}
|
}
|
||||||
|
"llm.thinking_enabled" => {
|
||||||
|
manager.config_mut().llm.thinking_enabled = value == "true";
|
||||||
|
}
|
||||||
|
"llm.thinking_budget_tokens" => {
|
||||||
|
let budget: u32 = value.parse()?;
|
||||||
|
manager.config_mut().llm.thinking_budget_tokens = Some(budget);
|
||||||
|
}
|
||||||
"llm.api_key_storage" => {
|
"llm.api_key_storage" => {
|
||||||
let valid_values = vec!["keyring", "config", "environment"];
|
let valid_values = ["keyring", "config", "environment"];
|
||||||
if !valid_values.contains(&value) {
|
if !valid_values.contains(&value) {
|
||||||
bail!("Invalid value: {}. Use: {}", value, valid_values.join(", "));
|
bail!("Invalid value: {}. Use: {}", value, valid_values.join(", "));
|
||||||
}
|
}
|
||||||
@@ -433,6 +440,8 @@ impl ConfigCommand {
|
|||||||
"llm.max_tokens" => config.llm.max_tokens.to_string(),
|
"llm.max_tokens" => config.llm.max_tokens.to_string(),
|
||||||
"llm.temperature" => config.llm.temperature.to_string(),
|
"llm.temperature" => config.llm.temperature.to_string(),
|
||||||
"llm.timeout" => config.llm.timeout.to_string(),
|
"llm.timeout" => config.llm.timeout.to_string(),
|
||||||
|
"llm.thinking_enabled" => config.llm.thinking_enabled.to_string(),
|
||||||
|
"llm.thinking_budget_tokens" => config.llm.thinking_budget_tokens.map(|v| v.to_string()).unwrap_or_else(|| "none".to_string()),
|
||||||
"llm.api_key_storage" => config.llm.api_key_storage.clone(),
|
"llm.api_key_storage" => config.llm.api_key_storage.clone(),
|
||||||
"commit.format" => config.commit.format.to_string(),
|
"commit.format" => config.commit.format.to_string(),
|
||||||
"commit.auto_generate" => config.commit.auto_generate.to_string(),
|
"commit.auto_generate" => config.commit.auto_generate.to_string(),
|
||||||
@@ -681,15 +690,13 @@ impl ConfigCommand {
|
|||||||
l.to_string()
|
l.to_string()
|
||||||
} else {
|
} else {
|
||||||
println!("{}", "Select Output Language:".bold());
|
println!("{}", "Select Output Language:".bold());
|
||||||
let languages = vec![
|
let languages = [("en", "English"),
|
||||||
("en", "English"),
|
|
||||||
("zh", "中文"),
|
("zh", "中文"),
|
||||||
("ja", "日本語"),
|
("ja", "日本語"),
|
||||||
("ko", "한국어"),
|
("ko", "한국어"),
|
||||||
("es", "Español"),
|
("es", "Español"),
|
||||||
("fr", "Français"),
|
("fr", "Français"),
|
||||||
("de", "Deutsch"),
|
("de", "Deutsch")];
|
||||||
];
|
|
||||||
|
|
||||||
let lang_names: Vec<&str> = languages.iter().map(|(_, n)| *n).collect();
|
let lang_names: Vec<&str> = languages.iter().map(|(_, n)| *n).collect();
|
||||||
let idx = Select::new()
|
let idx = Select::new()
|
||||||
@@ -879,8 +886,8 @@ impl ConfigCommand {
|
|||||||
manager.import(&config_content)?;
|
manager.import(&config_content)?;
|
||||||
manager.save()?;
|
manager.save()?;
|
||||||
|
|
||||||
if let (Some(pats), Some(pwd)) = (encrypted_pats, pwd) {
|
if let (Some(pats), Some(pwd)) = (encrypted_pats, pwd)
|
||||||
if !pats.is_empty() {
|
&& !pats.is_empty() {
|
||||||
println!();
|
println!();
|
||||||
println!("{}", "Importing Personal Access Tokens...".bold());
|
println!("{}", "Importing Personal Access Tokens...".bold());
|
||||||
|
|
||||||
@@ -951,7 +958,6 @@ impl ConfigCommand {
|
|||||||
println!("{} {} token(s) failed to import", "⚠".yellow(), failed_count);
|
println!("{} {} token(s) failed to import", "⚠".yellow(), failed_count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
println!("{} Configuration imported from {}", "✓".green(), file);
|
println!("{} Configuration imported from {}", "✓".green(), file);
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -969,31 +975,35 @@ impl ConfigCommand {
|
|||||||
println!("Ollama models (local):");
|
println!("Ollama models (local):");
|
||||||
println!(" llama2, llama2-uncensored, llama2:13b");
|
println!(" llama2, llama2-uncensored, llama2:13b");
|
||||||
println!(" codellama, codellama:34b");
|
println!(" codellama, codellama:34b");
|
||||||
println!(" mistral, mixtral");
|
println!(" mistral, mixtral, phi, gemma");
|
||||||
println!(" phi, gemma");
|
|
||||||
println!("\nRun 'ollama list' to see installed models");
|
println!("\nRun 'ollama list' to see installed models");
|
||||||
}
|
}
|
||||||
"openai" => {
|
"openai" => {
|
||||||
println!("OpenAI models:");
|
println!("OpenAI models:");
|
||||||
println!(" gpt-4, gpt-4-turbo, gpt-4o");
|
println!(" o-series (reasoning): o4-mini, o3, o3-mini, o1, o1-mini, o1-pro");
|
||||||
println!(" gpt-3.5-turbo, gpt-3.5-turbo-16k");
|
println!(" GPT-4: gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, gpt-4o, gpt-4o-mini");
|
||||||
|
println!(" Legacy: gpt-4-turbo, gpt-4, gpt-3.5-turbo");
|
||||||
|
println!("\nUse --think/-t with o-series models for reasoning mode.");
|
||||||
}
|
}
|
||||||
"anthropic" => {
|
"anthropic" => {
|
||||||
println!("Anthropic Claude models:");
|
println!("Anthropic Claude models:");
|
||||||
println!(" claude-3-opus-20240229");
|
println!(" Claude 4 (thinking): claude-opus-4-7, claude-sonnet-4-6, claude-haiku-4-5");
|
||||||
println!(" claude-3-sonnet-20240229");
|
println!(" Claude 3.5: claude-3-opus-20240229, claude-3-sonnet-20240229, claude-3-haiku-20240307");
|
||||||
println!(" claude-3-haiku-20240307");
|
println!(" Legacy: claude-2.1, claude-2.0, claude-instant-1.2");
|
||||||
|
println!("\nUse --think/-t with Claude 4 models for extended thinking.");
|
||||||
}
|
}
|
||||||
"kimi" => {
|
"kimi" => {
|
||||||
println!("Kimi (Moonshot AI) models:");
|
println!("Kimi (Moonshot AI) models:");
|
||||||
println!(" moonshot-v1-8k");
|
println!(" K2 (thinking): kimi-k2.6, kimi-k2.5, kimi-k2-thinking, kimi-k2-thinking-turbo");
|
||||||
println!(" moonshot-v1-32k");
|
println!(" K2 instruct: kimi-k2-instruct, kimi-k2-instruct-0905");
|
||||||
println!(" moonshot-v1-128k");
|
println!(" Legacy: moonshot-v1-8k, moonshot-v1-32k, moonshot-v1-128k");
|
||||||
|
println!("\nUse --think/-t with K2 models for thinking mode.");
|
||||||
}
|
}
|
||||||
"deepseek" => {
|
"deepseek" => {
|
||||||
println!("DeepSeek models:");
|
println!("DeepSeek models:");
|
||||||
println!(" deepseek-chat");
|
println!(" V4: deepseek-v4-flash, deepseek-v4-pro");
|
||||||
println!(" deepseek-reasoner");
|
println!(" Legacy: deepseek-chat, deepseek-reasoner (deprecated 2026-07-24)");
|
||||||
|
println!("\nUse --think/-t for reasoning mode.");
|
||||||
}
|
}
|
||||||
"openrouter" => {
|
"openrouter" => {
|
||||||
println!("OpenRouter models (examples):");
|
println!("OpenRouter models (examples):");
|
||||||
|
|||||||
@@ -102,15 +102,13 @@ impl InitCommand {
|
|||||||
println!("\n{}", messages.setup_profile().bold());
|
println!("\n{}", messages.setup_profile().bold());
|
||||||
|
|
||||||
println!("\n{}", messages.select_output_language().bold());
|
println!("\n{}", messages.select_output_language().bold());
|
||||||
let languages = vec![
|
let languages = [Language::English,
|
||||||
Language::English,
|
|
||||||
Language::Chinese,
|
Language::Chinese,
|
||||||
Language::Japanese,
|
Language::Japanese,
|
||||||
Language::Korean,
|
Language::Korean,
|
||||||
Language::Spanish,
|
Language::Spanish,
|
||||||
Language::French,
|
Language::French,
|
||||||
Language::German,
|
Language::German];
|
||||||
];
|
|
||||||
let language_names: Vec<String> = languages.iter().map(|l| l.display_name().to_string()).collect();
|
let language_names: Vec<String> = languages.iter().map(|l| l.display_name().to_string()).collect();
|
||||||
let language_idx = Select::new()
|
let language_idx = Select::new()
|
||||||
.items(&language_names)
|
.items(&language_names)
|
||||||
@@ -297,12 +295,11 @@ impl InitCommand {
|
|||||||
manager.set_llm_model(model);
|
manager.set_llm_model(model);
|
||||||
manager.set_llm_base_url(base_url);
|
manager.set_llm_base_url(base_url);
|
||||||
|
|
||||||
if let Some(key) = api_key {
|
if let Some(key) = api_key
|
||||||
if provider_needs_api_key(&provider) {
|
&& provider_needs_api_key(&provider) {
|
||||||
manager.set_api_key(&key)?;
|
manager.set_api_key(&key)?;
|
||||||
println!("\n{} {}", "✓".green(), "API key stored securely in system keyring.".green());
|
println!("\n{} {}", "✓".green(), "API key stored securely in system keyring.".green());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ impl ProfileCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let confirm = Confirm::new()
|
let confirm = Confirm::new()
|
||||||
.with_prompt(&format!("Are you sure you want to remove profile '{}'?", name))
|
.with_prompt(format!("Are you sure you want to remove profile '{}'?", name))
|
||||||
.default(false)
|
.default(false)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
|
|
||||||
@@ -564,7 +564,7 @@ impl ProfileCommand {
|
|||||||
|
|
||||||
if manager.has_profile(&profile_name) {
|
if manager.has_profile(&profile_name) {
|
||||||
let overwrite = Confirm::new()
|
let overwrite = Confirm::new()
|
||||||
.with_prompt(&format!("Profile '{}' already exists. Overwrite?", profile_name))
|
.with_prompt(format!("Profile '{}' already exists. Overwrite?", profile_name))
|
||||||
.default(false)
|
.default(false)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
if !overwrite {
|
if !overwrite {
|
||||||
@@ -575,7 +575,7 @@ impl ProfileCommand {
|
|||||||
|
|
||||||
let description: String = Input::new()
|
let description: String = Input::new()
|
||||||
.with_prompt("Description (optional)")
|
.with_prompt("Description (optional)")
|
||||||
.default(format!("Imported from existing git config"))
|
.default("Imported from existing git config".to_string())
|
||||||
.allow_empty(true)
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
||||||
@@ -849,7 +849,7 @@ impl ProfileCommand {
|
|||||||
println!("{}", "─".repeat(40));
|
println!("{}", "─".repeat(40));
|
||||||
|
|
||||||
let token_value: String = Input::new()
|
let token_value: String = Input::new()
|
||||||
.with_prompt(&format!("Token for {}", service))
|
.with_prompt(format!("Token for {}", service))
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
||||||
let token_type_options = vec!["Personal", "OAuth", "Deploy", "App"];
|
let token_type_options = vec!["Personal", "OAuth", "Deploy", "App"];
|
||||||
@@ -895,7 +895,7 @@ impl ProfileCommand {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let confirm = Confirm::new()
|
let confirm = Confirm::new()
|
||||||
.with_prompt(&format!("Remove token '{}' from profile '{}'?", service, profile_name))
|
.with_prompt(format!("Remove token '{}' from profile '{}'?", service, profile_name))
|
||||||
.default(false)
|
.default(false)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ pub struct TagCommand {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
dry_run: bool,
|
dry_run: bool,
|
||||||
|
|
||||||
|
/// Enable thinking mode for this tag (override config)
|
||||||
|
#[arg(short = 't', long)]
|
||||||
|
think: bool,
|
||||||
|
|
||||||
/// Skip interactive prompts
|
/// Skip interactive prompts
|
||||||
#[arg(short = 'y', long)]
|
#[arg(short = 'y', long)]
|
||||||
yes: bool,
|
yes: bool,
|
||||||
@@ -285,7 +289,7 @@ impl TagCommand {
|
|||||||
|
|
||||||
println!("{}", messages.ai_generating_tag(commits.len()));
|
println!("{}", messages.ai_generating_tag(commits.len()));
|
||||||
|
|
||||||
let generator = ContentGenerator::new(&manager).await?;
|
let generator = ContentGenerator::new_with_think(&manager, self.think).await?;
|
||||||
generator.generate_tag_message(version, &commits, language).await
|
generator.generate_tag_message(version, &commits, language).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -136,11 +136,10 @@ impl ConfigManager {
|
|||||||
|
|
||||||
/// Set default profile
|
/// Set default profile
|
||||||
pub fn set_default_profile(&mut self, name: Option<String>) -> Result<()> {
|
pub fn set_default_profile(&mut self, name: Option<String>) -> Result<()> {
|
||||||
if let Some(ref n) = name {
|
if let Some(ref n) = name
|
||||||
if !self.config.profiles.contains_key(n) {
|
&& !self.config.profiles.contains_key(n) {
|
||||||
bail!("Profile '{}' does not exist", n);
|
bail!("Profile '{}' does not exist", n);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
self.config.default_profile = name;
|
self.config.default_profile = name;
|
||||||
self.modified = true;
|
self.modified = true;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -111,9 +111,13 @@ pub struct LlmConfig {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
|
|
||||||
/// Enable thinking/reasoning mode (deepseek, kimi)
|
/// Enable thinking/reasoning mode (deepseek, kimi, anthropic)
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub thinking_enabled: bool,
|
pub thinking_enabled: bool,
|
||||||
|
|
||||||
|
/// Budget tokens for thinking mode (Anthropic Claude 4)
|
||||||
|
#[serde(default)]
|
||||||
|
pub thinking_budget_tokens: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_api_key_storage() -> String {
|
fn default_api_key_storage() -> String {
|
||||||
@@ -132,6 +136,7 @@ impl Default for LlmConfig {
|
|||||||
api_key_storage: default_api_key_storage(),
|
api_key_storage: default_api_key_storage(),
|
||||||
api_key: None,
|
api_key: None,
|
||||||
thinking_enabled: false,
|
thinking_enabled: false,
|
||||||
|
thinking_budget_tokens: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,9 +119,7 @@ impl GitProfile {
|
|||||||
|
|
||||||
/// Get signing key (from GPG config or direct)
|
/// Get signing key (from GPG config or direct)
|
||||||
pub fn signing_key(&self) -> Option<&str> {
|
pub fn signing_key(&self) -> Option<&str> {
|
||||||
self.signing_key
|
self.signing_key.as_deref()
|
||||||
.as_ref()
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str()))
|
.or_else(|| self.gpg.as_ref().map(|g| g.key_id.as_str()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,8 +173,8 @@ impl GitProfile {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref ssh) = self.ssh {
|
if let Some(ref ssh) = self.ssh
|
||||||
if let Some(ref key_path) = ssh.private_key_path {
|
&& let Some(ref key_path) = ssh.private_key_path {
|
||||||
let path_str = key_path.display().to_string();
|
let path_str = key_path.display().to_string();
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
{
|
{
|
||||||
@@ -189,7 +187,6 @@ impl GitProfile {
|
|||||||
&format!("ssh -i '{}'", path_str))?;
|
&format!("ssh -i '{}'", path_str))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -213,8 +210,8 @@ impl GitProfile {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref ssh) = self.ssh {
|
if let Some(ref ssh) = self.ssh
|
||||||
if let Some(ref key_path) = ssh.private_key_path {
|
&& let Some(ref key_path) = ssh.private_key_path {
|
||||||
let path_str = key_path.display().to_string();
|
let path_str = key_path.display().to_string();
|
||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
{
|
{
|
||||||
@@ -227,7 +224,6 @@ impl GitProfile {
|
|||||||
&format!("ssh -i '{}'", path_str))?;
|
&format!("ssh -i '{}'", path_str))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -264,8 +260,8 @@ impl GitProfile {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(profile_key) = self.signing_key() {
|
if let Some(profile_key) = self.signing_key()
|
||||||
if git_signing_key.as_deref() != Some(profile_key) {
|
&& git_signing_key.as_deref() != Some(profile_key) {
|
||||||
comparison.matches = false;
|
comparison.matches = false;
|
||||||
comparison.differences.push(ConfigDifference {
|
comparison.differences.push(ConfigDifference {
|
||||||
key: "user.signingkey".to_string(),
|
key: "user.signingkey".to_string(),
|
||||||
@@ -273,7 +269,6 @@ impl GitProfile {
|
|||||||
git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()),
|
git_value: git_signing_key.unwrap_or_else(|| "<not set>".to_string()),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(comparison)
|
Ok(comparison)
|
||||||
}
|
}
|
||||||
@@ -281,6 +276,7 @@ impl GitProfile {
|
|||||||
|
|
||||||
/// Profile settings
|
/// Profile settings
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[derive(Default)]
|
||||||
pub struct ProfileSettings {
|
pub struct ProfileSettings {
|
||||||
/// Automatically sign commits
|
/// Automatically sign commits
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -307,18 +303,6 @@ pub struct ProfileSettings {
|
|||||||
pub commit_template: Option<String>,
|
pub commit_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ProfileSettings {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
auto_sign_commits: false,
|
|
||||||
auto_sign_tags: false,
|
|
||||||
default_commit_format: None,
|
|
||||||
repo_patterns: vec![],
|
|
||||||
llm_provider: None,
|
|
||||||
commit_template: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// SSH configuration
|
/// SSH configuration
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -349,17 +333,15 @@ pub struct SshConfig {
|
|||||||
impl SshConfig {
|
impl SshConfig {
|
||||||
/// Validate SSH configuration
|
/// Validate SSH configuration
|
||||||
pub fn validate(&self) -> Result<()> {
|
pub fn validate(&self) -> Result<()> {
|
||||||
if let Some(ref path) = self.private_key_path {
|
if let Some(ref path) = self.private_key_path
|
||||||
if !path.exists() {
|
&& !path.exists() {
|
||||||
bail!("SSH private key does not exist: {:?}", path);
|
bail!("SSH private key does not exist: {:?}", path);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref path) = self.public_key_path {
|
if let Some(ref path) = self.public_key_path
|
||||||
if !path.exists() {
|
&& !path.exists() {
|
||||||
bail!("SSH public key does not exist: {:?}", path);
|
bail!("SSH public key does not exist: {:?}", path);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -495,7 +477,9 @@ impl TokenConfig {
|
|||||||
/// Token type
|
/// Token type
|
||||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
|
#[derive(Default)]
|
||||||
pub enum TokenType {
|
pub enum TokenType {
|
||||||
|
#[default]
|
||||||
None,
|
None,
|
||||||
Personal,
|
Personal,
|
||||||
OAuth,
|
OAuth,
|
||||||
@@ -503,11 +487,6 @@ pub enum TokenType {
|
|||||||
App,
|
App,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TokenType {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for TokenType {
|
impl std::fmt::Display for TokenType {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
|||||||
@@ -12,15 +12,43 @@ pub struct ContentGenerator {
|
|||||||
impl ContentGenerator {
|
impl ContentGenerator {
|
||||||
/// Create new content generator
|
/// Create new content generator
|
||||||
pub async fn new(manager: &ConfigManager) -> Result<Self> {
|
pub async fn new(manager: &ConfigManager) -> Result<Self> {
|
||||||
let llm_client = LlmClient::from_config(manager).await?;
|
Self::new_with_think(manager, false).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new content generator with thinking override
|
||||||
|
pub async fn new_with_think(manager: &ConfigManager, think_override: bool) -> Result<Self> {
|
||||||
|
let mut thinking_enabled = if think_override {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
manager.config().llm.thinking_enabled
|
||||||
|
};
|
||||||
|
|
||||||
|
// Validate thinking support per provider
|
||||||
|
if thinking_enabled {
|
||||||
|
let provider = manager.llm_provider();
|
||||||
|
if !Self::supports_thinking(provider) {
|
||||||
|
eprintln!(
|
||||||
|
"Warning: Provider '{}' does not support thinking mode. \
|
||||||
|
Disabling thinking for this invocation.",
|
||||||
|
provider
|
||||||
|
);
|
||||||
|
thinking_enabled = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let llm_client = LlmClient::from_config_with_think(manager, thinking_enabled).await?;
|
||||||
|
|
||||||
if !llm_client.is_available().await {
|
if !llm_client.is_available().await {
|
||||||
anyhow::bail!("LLM provider '{}' is not available", manager.llm_provider());
|
anyhow::bail!("LLM provider '{}' is not available", manager.llm_provider());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self { llm_client })
|
Ok(Self { llm_client })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_thinking(provider: &str) -> bool {
|
||||||
|
matches!(provider, "deepseek" | "kimi" | "anthropic" | "openai")
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate commit message from diff
|
/// Generate commit message from diff
|
||||||
pub async fn generate_commit_message(
|
pub async fn generate_commit_message(
|
||||||
&self,
|
&self,
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ impl ChangelogGenerator {
|
|||||||
_date: DateTime<Utc>,
|
_date: DateTime<Utc>,
|
||||||
commits: &[CommitInfo],
|
commits: &[CommitInfo],
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let mut output = format!("## What's Changed\n\n");
|
let mut output = "## What's Changed\n\n".to_string();
|
||||||
|
|
||||||
// Group by type
|
// Group by type
|
||||||
let mut features = vec![];
|
let mut features = vec![];
|
||||||
@@ -427,19 +427,16 @@ pub fn parse_versions(changelog: &str) -> Vec<(String, String)> {
|
|||||||
let mut versions = vec![];
|
let mut versions = vec![];
|
||||||
|
|
||||||
for line in changelog.lines() {
|
for line in changelog.lines() {
|
||||||
if line.starts_with("## [") {
|
if line.starts_with("## [")
|
||||||
if let Some(start) = line.find('[') {
|
&& let Some(start) = line.find('[')
|
||||||
if let Some(end) = line.find(']') {
|
&& let Some(end) = line.find(']') {
|
||||||
let version = &line[start + 1..end];
|
let version = &line[start + 1..end];
|
||||||
if version != "Unreleased" {
|
if version != "Unreleased"
|
||||||
if let Some(date_start) = line.find(" - ") {
|
&& let Some(date_start) = line.find(" - ") {
|
||||||
let date = &line[date_start + 3..].trim();
|
let date = &line[date_start + 3..].trim();
|
||||||
versions.push((version.to_string(), date.to_string()));
|
versions.push((version.to_string(), date.to_string()));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
versions
|
versions
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ use anyhow::{bail, Context, Result};
|
|||||||
use git2::{Repository, Signature, StatusOptions, Config, Oid, ObjectType};
|
use git2::{Repository, Signature, StatusOptions, Config, Oid, ObjectType};
|
||||||
use std::path::{Path, PathBuf, Component};
|
use std::path::{Path, PathBuf, Component};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tempfile;
|
|
||||||
|
|
||||||
pub mod changelog;
|
pub mod changelog;
|
||||||
pub mod commit;
|
pub mod commit;
|
||||||
@@ -15,16 +14,14 @@ fn normalize_path_for_git2(path: &Path) -> PathBuf {
|
|||||||
#[cfg(target_os = "windows")]
|
#[cfg(target_os = "windows")]
|
||||||
{
|
{
|
||||||
let path_str = path.to_string_lossy();
|
let path_str = path.to_string_lossy();
|
||||||
if path_str.starts_with(r"\\?\") {
|
if path_str.starts_with(r"\\?\")
|
||||||
if let Some(stripped) = path_str.strip_prefix(r"\\?\") {
|
&& let Some(stripped) = path_str.strip_prefix(r"\\?\") {
|
||||||
normalized = PathBuf::from(stripped);
|
normalized = PathBuf::from(stripped);
|
||||||
}
|
}
|
||||||
}
|
if path_str.starts_with(r"\\?\UNC\")
|
||||||
if path_str.starts_with(r"\\?\UNC\") {
|
&& let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") {
|
||||||
if let Some(stripped) = path_str.strip_prefix(r"\\?\UNC\") {
|
|
||||||
normalized = PathBuf::from(format!(r"\\{}", stripped));
|
normalized = PathBuf::from(format!(r"\\{}", stripped));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
normalized
|
normalized
|
||||||
@@ -75,7 +72,7 @@ fn try_open_repo_with_git2(path: &Path) -> Result<Repository> {
|
|||||||
let discover_opts = git2::RepositoryOpenFlags::empty();
|
let discover_opts = git2::RepositoryOpenFlags::empty();
|
||||||
let ceiling_dirs: [&str; 0] = [];
|
let ceiling_dirs: [&str; 0] = [];
|
||||||
|
|
||||||
let repo = Repository::open_ext(&normalized, discover_opts, &ceiling_dirs)
|
let repo = Repository::open_ext(&normalized, discover_opts, ceiling_dirs)
|
||||||
.or_else(|_| Repository::discover(&normalized))
|
.or_else(|_| Repository::discover(&normalized))
|
||||||
.or_else(|_| Repository::open(&normalized));
|
.or_else(|_| Repository::open(&normalized));
|
||||||
|
|
||||||
@@ -84,7 +81,7 @@ fn try_open_repo_with_git2(path: &Path) -> Result<Repository> {
|
|||||||
|
|
||||||
fn try_open_repo_with_git_cli(path: &Path) -> Result<Repository> {
|
fn try_open_repo_with_git_cli(path: &Path) -> Result<Repository> {
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["rev-parse", "--show-toplevel"])
|
.args(["rev-parse", "--show-toplevel"])
|
||||||
.current_dir(path)
|
.current_dir(path)
|
||||||
.output()
|
.output()
|
||||||
.context("Failed to execute git command")?;
|
.context("Failed to execute git command")?;
|
||||||
@@ -330,7 +327,7 @@ impl GitRepo {
|
|||||||
pub fn get_staged_diff(&self) -> Result<String> {
|
pub fn get_staged_diff(&self) -> Result<String> {
|
||||||
// Use git CLI to get staged diff for better compatibility
|
// Use git CLI to get staged diff for better compatibility
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["diff", "--cached"])
|
.args(["diff", "--cached"])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()
|
.output()
|
||||||
.with_context(|| "Failed to get staged diff with git command")?;
|
.with_context(|| "Failed to get staged diff with git command")?;
|
||||||
@@ -509,11 +506,10 @@ impl GitRepo {
|
|||||||
let mut files = vec![];
|
let mut files = vec![];
|
||||||
for entry in statuses.iter() {
|
for entry in statuses.iter() {
|
||||||
let status = entry.status();
|
let status = entry.status();
|
||||||
if status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange() {
|
if (status.is_index_new() || status.is_index_modified() || status.is_index_deleted() || status.is_index_renamed() || status.is_index_typechange())
|
||||||
if let Some(path) = entry.path() {
|
&& let Some(path) = entry.path() {
|
||||||
files.push(path.to_string());
|
files.push(path.to_string());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(files)
|
Ok(files)
|
||||||
@@ -542,7 +538,7 @@ impl GitRepo {
|
|||||||
pub fn stage_all(&self) -> Result<()> {
|
pub fn stage_all(&self) -> Result<()> {
|
||||||
// Use git command for reliable staging (handles all edge cases)
|
// Use git command for reliable staging (handles all edge cases)
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["add", "-A"])
|
.args(["add", "-A"])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()
|
.output()
|
||||||
.with_context(|| "Failed to stage changes with git command")?;
|
.with_context(|| "Failed to stage changes with git command")?;
|
||||||
@@ -626,7 +622,7 @@ impl GitRepo {
|
|||||||
std::fs::write(temp_file.path(), message)?;
|
std::fs::write(temp_file.path(), message)?;
|
||||||
|
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["commit", "-S", "-F", temp_file.path().to_str().unwrap()])
|
.args(["commit", "-S", "-F", temp_file.path().to_str().unwrap()])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()?;
|
.output()?;
|
||||||
|
|
||||||
@@ -801,7 +797,7 @@ impl GitRepo {
|
|||||||
/// Create signed tag using git CLI
|
/// Create signed tag using git CLI
|
||||||
fn create_signed_tag_with_git2(&self, name: &str, message: &str, _signature: &Signature, _target_id: Oid) -> Result<()> {
|
fn create_signed_tag_with_git2(&self, name: &str, message: &str, _signature: &Signature, _target_id: Oid) -> Result<()> {
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["tag", "-s", name, "-m", message])
|
.args(["tag", "-s", name, "-m", message])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()?;
|
.output()?;
|
||||||
|
|
||||||
@@ -827,7 +823,7 @@ impl GitRepo {
|
|||||||
/// Push to remote
|
/// Push to remote
|
||||||
pub fn push(&self, remote: &str, refspec: &str) -> Result<()> {
|
pub fn push(&self, remote: &str, refspec: &str) -> Result<()> {
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["push", remote, refspec])
|
.args(["push", remote, refspec])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()?;
|
.output()?;
|
||||||
|
|
||||||
@@ -856,7 +852,7 @@ impl GitRepo {
|
|||||||
pub fn status_summary(&self) -> Result<StatusSummary> {
|
pub fn status_summary(&self) -> Result<StatusSummary> {
|
||||||
// Use git CLI for more reliable status detection
|
// Use git CLI for more reliable status detection
|
||||||
let output = std::process::Command::new("git")
|
let output = std::process::Command::new("git")
|
||||||
.args(&["status", "--porcelain"])
|
.args(["status", "--porcelain"])
|
||||||
.current_dir(&self.path)
|
.current_dir(&self.path)
|
||||||
.output()
|
.output()
|
||||||
.with_context(|| "Failed to get status with git command")?;
|
.with_context(|| "Failed to get status with git command")?;
|
||||||
@@ -1015,20 +1011,17 @@ pub fn find_repo<P: AsRef<Path>>(start_path: P) -> Result<GitRepo> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(output) = std::process::Command::new("git")
|
if let Ok(output) = std::process::Command::new("git")
|
||||||
.args(&["rev-parse", "--show-toplevel"])
|
.args(["rev-parse", "--show-toplevel"])
|
||||||
.current_dir(&resolved_start)
|
.current_dir(&resolved_start)
|
||||||
.output()
|
.output()
|
||||||
{
|
&& output.status.success() {
|
||||||
if output.status.success() {
|
|
||||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||||
let git_root = stdout.trim();
|
let git_root = stdout.trim();
|
||||||
if !git_root.is_empty() {
|
if !git_root.is_empty()
|
||||||
if let Ok(repo) = GitRepo::open(git_root) {
|
&& let Ok(repo) = GitRepo::open(git_root) {
|
||||||
return Ok(repo);
|
return Ok(repo);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
let diagnosis = diagnose_repo_issue(&resolved_start);
|
let diagnosis = diagnose_repo_issue(&resolved_start);
|
||||||
|
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ pub fn delete_tag(repo: &GitRepo, name: &str, remote: Option<&str>) -> Result<()
|
|||||||
|
|
||||||
let refspec = format!(":refs/tags/{}", name);
|
let refspec = format!(":refs/tags/{}", name);
|
||||||
let output = Command::new("git")
|
let output = Command::new("git")
|
||||||
.args(&["push", remote, &refspec])
|
.args(["push", remote, &refspec])
|
||||||
.current_dir(repo.path())
|
.current_dir(repo.path())
|
||||||
.output()?;
|
.output()?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
use super::thinking::ThinkingStateManager;
|
||||||
use super::{create_http_client, LlmProvider};
|
use super::{create_http_client, LlmProvider};
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// Anthropic Claude API client
|
/// Anthropic Claude API client
|
||||||
@@ -9,6 +11,12 @@ pub struct AnthropicClient {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
thinking_enabled: bool,
|
||||||
|
thinking_budget_tokens: u32,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
top_p: Option<f32>,
|
||||||
|
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -17,24 +25,58 @@ struct MessagesRequest {
|
|||||||
max_tokens: u32,
|
max_tokens: u32,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
top_p: Option<f32>,
|
||||||
messages: Vec<AnthropicMessage>,
|
messages: Vec<AnthropicMessage>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
system: Option<String>,
|
system: Option<Vec<SystemContent>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
thinking: Option<ThinkingConfig>,
|
||||||
|
stream: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Clone)]
|
||||||
|
struct SystemContent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
content_type: String,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct ThinkingConfig {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
thinking_type: String,
|
||||||
|
budget_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
struct AnthropicMessage {
|
struct AnthropicMessage {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
content: AnthropicContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum AnthropicContent {
|
||||||
|
Text(String),
|
||||||
|
Blocks(Vec<ContentBlock>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
struct ContentBlock {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
content_type: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
text: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct MessagesResponse {
|
struct MessagesResponse {
|
||||||
content: Vec<ContentBlock>,
|
content: Vec<ResponseContentBlock>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ContentBlock {
|
struct ResponseContentBlock {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
content_type: String,
|
content_type: String,
|
||||||
text: String,
|
text: String,
|
||||||
@@ -52,31 +94,112 @@ struct AnthropicError {
|
|||||||
message: String,
|
message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Streaming SSE event structures ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct SseEvent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
event_type: String,
|
||||||
|
#[serde(default)]
|
||||||
|
message: Option<SseMessage>,
|
||||||
|
#[serde(default)]
|
||||||
|
index: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
content_block: Option<SseContentBlock>,
|
||||||
|
#[serde(default)]
|
||||||
|
delta: Option<SseDelta>,
|
||||||
|
#[serde(default)]
|
||||||
|
usage: Option<SseUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct SseMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<Vec<SseContentBlock>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct SseContentBlock {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
content_type: String,
|
||||||
|
#[serde(default)]
|
||||||
|
thinking: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct SseDelta {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
delta_type: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
thinking: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct SseUsage {
|
||||||
|
#[serde(default)]
|
||||||
|
output_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
impl AnthropicClient {
|
impl AnthropicClient {
|
||||||
/// Create new Anthropic client
|
|
||||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(60))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
|
thinking_enabled: false,
|
||||||
|
thinking_budget_tokens: 1024,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
|
thinking_state: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set timeout
|
|
||||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||||
self.client = create_http_client(timeout)?;
|
self.client = create_http_client(timeout)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List available models
|
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||||
|
self.thinking_enabled = enabled;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_thinking_budget_tokens(mut self, budget_tokens: u32) -> Self {
|
||||||
|
self.thinking_budget_tokens = budget_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = temperature;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||||
|
self.thinking_state = Some(state);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
// Anthropic doesn't have a models API endpoint, return predefined list
|
|
||||||
Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect())
|
Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate API key
|
|
||||||
pub async fn validate_key(&self) -> Result<bool> {
|
pub async fn validate_key(&self) -> Result<bool> {
|
||||||
let url = "https://api.anthropic.com/v1/messages";
|
let url = "https://api.anthropic.com/v1/messages";
|
||||||
|
|
||||||
@@ -84,14 +207,18 @@ impl AnthropicClient {
|
|||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
max_tokens: 5,
|
max_tokens: 5,
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
|
top_p: None,
|
||||||
messages: vec![AnthropicMessage {
|
messages: vec![AnthropicMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi".to_string(),
|
content: AnthropicContent::Text("Hi".to_string()),
|
||||||
}],
|
}],
|
||||||
system: None,
|
system: None,
|
||||||
|
thinking: None,
|
||||||
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.post(url)
|
.post(url)
|
||||||
.header("x-api-key", &self.api_key)
|
.header("x-api-key", &self.api_key)
|
||||||
.header("anthropic-version", "2023-06-01")
|
.header("anthropic-version", "2023-06-01")
|
||||||
@@ -124,25 +251,28 @@ impl LlmProvider for AnthropicClient {
|
|||||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||||
let messages = vec![AnthropicMessage {
|
let messages = vec![AnthropicMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: prompt.to_string(),
|
content: AnthropicContent::Text(prompt.to_string()),
|
||||||
}];
|
}];
|
||||||
|
|
||||||
self.messages_request(messages, None).await
|
self.messages_request_with_retry(messages, None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||||
let messages = vec![AnthropicMessage {
|
let messages = vec![AnthropicMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: user.to_string(),
|
content: AnthropicContent::Text(user.to_string()),
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let system = if system.is_empty() {
|
let system = if system.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(system.to_string())
|
Some(vec![SystemContent {
|
||||||
|
content_type: "text".to_string(),
|
||||||
|
text: system.to_string(),
|
||||||
|
}])
|
||||||
};
|
};
|
||||||
|
|
||||||
self.messages_request(messages, system).await
|
self.messages_request_with_retry(messages, system).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_available(&self) -> bool {
|
async fn is_available(&self) -> bool {
|
||||||
@@ -155,22 +285,81 @@ impl LlmProvider for AnthropicClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicClient {
|
impl AnthropicClient {
|
||||||
|
async fn messages_request_with_retry(
|
||||||
|
&self,
|
||||||
|
messages: Vec<AnthropicMessage>,
|
||||||
|
system: Option<Vec<SystemContent>>,
|
||||||
|
) -> Result<String> {
|
||||||
|
let mut last_error = None;
|
||||||
|
|
||||||
|
for attempt in 1..=3 {
|
||||||
|
match self
|
||||||
|
.messages_request(messages.clone(), system.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(result) => return Ok(result),
|
||||||
|
Err(e) => {
|
||||||
|
let err_msg = e.to_string();
|
||||||
|
let is_retryable = err_msg.contains("timeout")
|
||||||
|
|| err_msg.contains("connection")
|
||||||
|
|| err_msg.contains("temporary")
|
||||||
|
|| err_msg.contains("5")
|
||||||
|
&& (err_msg.contains("500")
|
||||||
|
|| err_msg.contains("502")
|
||||||
|
|| err_msg.contains("503")
|
||||||
|
|| err_msg.contains("504"));
|
||||||
|
|
||||||
|
if !is_retryable || attempt == 3 {
|
||||||
|
last_error = Some(e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||||
|
}
|
||||||
|
|
||||||
async fn messages_request(
|
async fn messages_request(
|
||||||
&self,
|
&self,
|
||||||
messages: Vec<AnthropicMessage>,
|
messages: Vec<AnthropicMessage>,
|
||||||
system: Option<String>,
|
system: Option<Vec<SystemContent>>,
|
||||||
|
) -> Result<String> {
|
||||||
|
if self.thinking_enabled {
|
||||||
|
self.streaming_messages_request(messages, system).await
|
||||||
|
} else {
|
||||||
|
self.non_streaming_messages_request(messages, system).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn non_streaming_messages_request(
|
||||||
|
&self,
|
||||||
|
messages: Vec<AnthropicMessage>,
|
||||||
|
system: Option<Vec<SystemContent>>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let url = "https://api.anthropic.com/v1/messages";
|
let url = "https://api.anthropic.com/v1/messages";
|
||||||
|
|
||||||
|
let temperature = if self.temperature == 0.0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.temperature)
|
||||||
|
};
|
||||||
|
|
||||||
let request = MessagesRequest {
|
let request = MessagesRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
max_tokens: 500,
|
max_tokens: self.max_tokens,
|
||||||
temperature: Some(0.7),
|
temperature,
|
||||||
|
top_p: self.top_p,
|
||||||
messages,
|
messages,
|
||||||
system,
|
system,
|
||||||
|
thinking: None,
|
||||||
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.post(url)
|
.post(url)
|
||||||
.header("x-api-key", &self.api_key)
|
.header("x-api-key", &self.api_key)
|
||||||
.header("anthropic-version", "2023-06-01")
|
.header("anthropic-version", "2023-06-01")
|
||||||
@@ -179,35 +368,205 @@ impl AnthropicClient {
|
|||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.context("Failed to send request to Anthropic")?;
|
.context("Failed to send request to Anthropic")?;
|
||||||
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let text = response.text().await.unwrap_or_default();
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
// Try to parse error
|
|
||||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type);
|
bail!(
|
||||||
|
"Anthropic API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
bail!("Anthropic API error: {} - {}", status, text);
|
bail!("Anthropic API error: {} - {}", status, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: MessagesResponse = response
|
let result: MessagesResponse = response
|
||||||
.json()
|
.json()
|
||||||
.await
|
.await
|
||||||
.context("Failed to parse Anthropic response")?;
|
.context("Failed to parse Anthropic response")?;
|
||||||
|
|
||||||
result.content
|
result
|
||||||
|
.content
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find(|c| c.content_type == "text")
|
.find(|c| c.content_type == "text")
|
||||||
.map(|c| c.text.trim().to_string())
|
.map(|c| c.text.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
|
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming request for thinking mode, filters thinking content blocks
|
||||||
|
async fn streaming_messages_request(
|
||||||
|
&self,
|
||||||
|
messages: Vec<AnthropicMessage>,
|
||||||
|
system: Option<Vec<SystemContent>>,
|
||||||
|
) -> Result<String> {
|
||||||
|
let url = "https://api.anthropic.com/v1/messages";
|
||||||
|
|
||||||
|
let thinking = ThinkingConfig {
|
||||||
|
thinking_type: "enabled".to_string(),
|
||||||
|
budget_tokens: self.thinking_budget_tokens,
|
||||||
|
};
|
||||||
|
|
||||||
|
// max_tokens must exceed budget_tokens
|
||||||
|
let max_tokens = (self.max_tokens).max(self.thinking_budget_tokens + 100);
|
||||||
|
|
||||||
|
let request = MessagesRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
max_tokens,
|
||||||
|
temperature: None, // must be omitted for thinking mode
|
||||||
|
top_p: None,
|
||||||
|
messages,
|
||||||
|
system,
|
||||||
|
thinking: Some(thinking),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(url)
|
||||||
|
.header("x-api-key", &self.api_key)
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to send streaming request to Anthropic")?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
|
bail!(
|
||||||
|
"Anthropic API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Anthropic API error: {} - {}", status, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut content_buffer = String::new();
|
||||||
|
let mut in_thinking = false;
|
||||||
|
let mut has_reasoning = false;
|
||||||
|
let mut has_content = false;
|
||||||
|
|
||||||
|
let thinking_state = self.thinking_state.as_ref();
|
||||||
|
|
||||||
|
let mut byte_stream = response.bytes_stream();
|
||||||
|
let mut line_buffer = String::new();
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||||
|
let chunk_str =
|
||||||
|
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||||
|
|
||||||
|
line_buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
while let Some(line_end) = line_buffer.find('\n') {
|
||||||
|
let line = line_buffer[..line_end].trim().to_string();
|
||||||
|
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse SSE event line
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if let Ok(event) = serde_json::from_str::<SseEvent>(data) {
|
||||||
|
match event.event_type.as_str() {
|
||||||
|
"content_block_start" => {
|
||||||
|
if let Some(ref block) = event.content_block {
|
||||||
|
if block.content_type == "thinking" {
|
||||||
|
in_thinking = true;
|
||||||
|
if !has_reasoning {
|
||||||
|
has_reasoning = true;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.start_thinking();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"content_block_delta" => {
|
||||||
|
if let Some(ref delta) = event.delta {
|
||||||
|
// Thinking delta - ignore content but track state
|
||||||
|
if delta.thinking.is_some() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text delta - collect
|
||||||
|
if in_thinking && delta.text.is_some() {
|
||||||
|
// Transition from thinking to text
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
in_thinking = false;
|
||||||
|
}
|
||||||
|
if let Some(ref text) = delta.text
|
||||||
|
&& !text.is_empty()
|
||||||
|
{
|
||||||
|
has_content = true;
|
||||||
|
content_buffer.push_str(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"content_block_stop" => {
|
||||||
|
if in_thinking {
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
in_thinking = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure thinking state is ended
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = content_buffer.trim().to_string();
|
||||||
|
|
||||||
|
if result.is_empty() {
|
||||||
|
if has_reasoning && !has_content {
|
||||||
|
bail!(
|
||||||
|
"Anthropic returned thinking content but no final answer. \
|
||||||
|
The model may have entered an incomplete thinking state. \
|
||||||
|
Please try again or disable thinking mode."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
bail!(
|
||||||
|
"No response from Anthropic. \
|
||||||
|
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Available Anthropic models
|
/// Available Anthropic models (Claude 4 series with extended thinking)
|
||||||
pub const ANTHROPIC_MODELS: &[&str] = &[
|
pub const ANTHROPIC_MODELS: &[&str] = &[
|
||||||
|
"claude-opus-4-7",
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
// Legacy models
|
||||||
"claude-3-opus-20240229",
|
"claude-3-opus-20240229",
|
||||||
"claude-3-sonnet-20240229",
|
"claude-3-sonnet-20240229",
|
||||||
"claude-3-haiku-20240307",
|
"claude-3-haiku-20240307",
|
||||||
@@ -216,7 +575,6 @@ pub const ANTHROPIC_MODELS: &[&str] = &[
|
|||||||
"claude-instant-1.2",
|
"claude-instant-1.2",
|
||||||
];
|
];
|
||||||
|
|
||||||
/// Check if a model name is valid
|
|
||||||
pub fn is_valid_model(model: &str) -> bool {
|
pub fn is_valid_model(model: &str) -> bool {
|
||||||
ANTHROPIC_MODELS.contains(&model)
|
ANTHROPIC_MODELS.contains(&model)
|
||||||
}
|
}
|
||||||
@@ -226,8 +584,61 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_validation() {
|
fn test_model_validation_claude4() {
|
||||||
|
assert!(is_valid_model("claude-opus-4-7"));
|
||||||
|
assert!(is_valid_model("claude-sonnet-4-6"));
|
||||||
|
assert!(is_valid_model("claude-haiku-4-5"));
|
||||||
assert!(is_valid_model("claude-3-sonnet-20240229"));
|
assert!(is_valid_model("claude-3-sonnet-20240229"));
|
||||||
assert!(!is_valid_model("invalid-model"));
|
assert!(!is_valid_model("invalid-model"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_config_serialization() {
|
||||||
|
let config = ThinkingConfig {
|
||||||
|
thinking_type: "enabled".to_string(),
|
||||||
|
budget_tokens: 2048,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"enabled""#));
|
||||||
|
assert!(json.contains(r#""budget_tokens":2048"#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_system_content_serialization() {
|
||||||
|
let content = SystemContent {
|
||||||
|
content_type: "text".to_string(),
|
||||||
|
text: "You are helpful.".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&content).unwrap();
|
||||||
|
assert!(json.contains(r#""type":"text""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sse_event_parsing_content_block_start() {
|
||||||
|
let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#;
|
||||||
|
let event: SseEvent = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(event.event_type, "content_block_start");
|
||||||
|
assert_eq!(
|
||||||
|
event.content_block.unwrap().content_type,
|
||||||
|
"thinking"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sse_event_parsing_text_delta() {
|
||||||
|
let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
|
||||||
|
let event: SseEvent = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(event.event_type, "content_block_delta");
|
||||||
|
assert_eq!(event.delta.unwrap().text, Some("Hello".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_anthropic_content_text() {
|
||||||
|
let msg = AnthropicMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: AnthropicContent::Text("Hello".to_string()),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(json.contains(r#""content":"Hello""#));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
use super::thinking::ThinkingStateManager;
|
||||||
use super::{create_http_client, LlmProvider};
|
use super::{create_http_client, LlmProvider};
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// DeepSeek API client
|
/// DeepSeek API client
|
||||||
@@ -11,6 +13,10 @@ pub struct DeepSeekClient {
|
|||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
thinking_enabled: bool,
|
thinking_enabled: bool,
|
||||||
|
reasoning_effort: Option<String>,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -21,9 +27,17 @@ struct ChatCompletionRequest {
|
|||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
frequency_penalty: Option<f32>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
thinking: Option<ThinkingConfig>,
|
thinking: Option<ThinkingConfig>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_effort: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -32,10 +46,12 @@ struct ThinkingConfig {
|
|||||||
thinking_type: String,
|
thinking_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct Message {
|
struct Message {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@@ -50,6 +66,29 @@ struct Choice {
|
|||||||
reasoning_content: Option<String>,
|
reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Streaming response structures ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChunk {
|
||||||
|
choices: Vec<StreamChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChoice {
|
||||||
|
delta: StreamDelta,
|
||||||
|
#[serde(default)]
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
index: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Default)]
|
||||||
|
struct StreamDelta {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ErrorResponse {
|
struct ErrorResponse {
|
||||||
error: ApiError,
|
error: ApiError,
|
||||||
@@ -63,22 +102,24 @@ struct ApiError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DeepSeekClient {
|
impl DeepSeekClient {
|
||||||
/// Create new DeepSeek client
|
|
||||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(300))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url: "https://api.deepseek.com/".to_string(),
|
base_url: "https://api.deepseek.com".to_string(),
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
thinking_enabled: false,
|
thinking_enabled: false,
|
||||||
|
reasoning_effort: None,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
thinking_state: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create with custom base URL
|
|
||||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(300))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
@@ -86,26 +127,48 @@ impl DeepSeekClient {
|
|||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
thinking_enabled: false,
|
thinking_enabled: false,
|
||||||
|
reasoning_effort: None,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
thinking_state: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set timeout
|
|
||||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||||
self.client = create_http_client(timeout)?;
|
self.client = create_http_client(timeout)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable or disable thinking mode
|
|
||||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||||
self.thinking_enabled = enabled;
|
self.thinking_enabled = enabled;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List available models
|
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
|
||||||
|
self.reasoning_effort = effort;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = temperature;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||||
|
self.thinking_state = Some(state);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
let url = format!("{}/models", self.base_url);
|
let url = format!("{}/models", self.base_url);
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.send()
|
.send()
|
||||||
@@ -120,11 +183,11 @@ impl DeepSeekClient {
|
|||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ModelsResponse {
|
struct ModelsResponse {
|
||||||
data: Vec<Model>,
|
data: Vec<ModelId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Model {
|
struct ModelId {
|
||||||
id: String,
|
id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +199,6 @@ impl DeepSeekClient {
|
|||||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate API key
|
|
||||||
pub async fn validate_key(&self) -> Result<bool> {
|
pub async fn validate_key(&self) -> Result<bool> {
|
||||||
match self.list_models().await {
|
match self.list_models().await {
|
||||||
Ok(_) => Ok(true),
|
Ok(_) => Ok(true),
|
||||||
@@ -155,14 +217,13 @@ impl DeepSeekClient {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LlmProvider for DeepSeekClient {
|
impl LlmProvider for DeepSeekClient {
|
||||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||||
let messages = vec![
|
let messages = vec![Message {
|
||||||
Message {
|
role: "user".to_string(),
|
||||||
role: "user".to_string(),
|
content: prompt.to_string(),
|
||||||
content: prompt.to_string(),
|
reasoning_content: None,
|
||||||
},
|
}];
|
||||||
];
|
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
self.chat_completion_with_retry(messages).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||||
@@ -172,15 +233,17 @@ impl LlmProvider for DeepSeekClient {
|
|||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: system.to_string(),
|
content: system.to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: user.to_string(),
|
content: user.to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
self.chat_completion_with_retry(messages).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_available(&self) -> bool {
|
async fn is_available(&self) -> bool {
|
||||||
@@ -193,6 +256,38 @@ impl LlmProvider for DeepSeekClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DeepSeekClient {
|
impl DeepSeekClient {
|
||||||
|
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
let mut last_error = None;
|
||||||
|
|
||||||
|
for attempt in 1..=3 {
|
||||||
|
match self.chat_completion(messages.clone()).await {
|
||||||
|
Ok(result) => return Ok(result),
|
||||||
|
Err(e) => {
|
||||||
|
let err_msg = e.to_string();
|
||||||
|
// 网络临时错误才重试
|
||||||
|
let is_retryable = err_msg.contains("timeout")
|
||||||
|
|| err_msg.contains("connection")
|
||||||
|
|| err_msg.contains("temporary")
|
||||||
|
|| err_msg.contains("5")
|
||||||
|
&& (err_msg.contains("500")
|
||||||
|
|| err_msg.contains("502")
|
||||||
|
|| err_msg.contains("503")
|
||||||
|
|| err_msg.contains("504"));
|
||||||
|
|
||||||
|
if !is_retryable || attempt == 3 {
|
||||||
|
last_error = Some(e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 指数退避
|
||||||
|
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||||
|
}
|
||||||
|
|
||||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
@@ -204,20 +299,59 @@ impl DeepSeekClient {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
// 思考模式下,temperature/top_p 等参数不应传递
|
||||||
model: self.model.clone(),
|
// 非思考模式下可以正常传递
|
||||||
messages,
|
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
|
||||||
max_tokens: Some(500),
|
if self.thinking_enabled {
|
||||||
temperature: Some(0.7),
|
(None, Some(self.max_tokens), None, None, None)
|
||||||
stream: false,
|
} else {
|
||||||
thinking,
|
(
|
||||||
|
Some(self.temperature),
|
||||||
|
Some(self.max_tokens),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let reasoning_effort = if self.thinking_enabled {
|
||||||
|
self.reasoning_effort.clone()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.client
|
let request = ChatCompletionRequest {
|
||||||
.post(&url)
|
model: self.model.clone(),
|
||||||
|
messages: messages.clone(),
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
stream: self.thinking_enabled,
|
||||||
|
thinking,
|
||||||
|
reasoning_effort,
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.thinking_enabled {
|
||||||
|
self.streaming_chat_completion(&url, &request).await
|
||||||
|
} else {
|
||||||
|
self.non_streaming_chat_completion(&url, &request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 非流式请求(非思考模式)
|
||||||
|
async fn non_streaming_chat_completion(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> Result<String> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(url)
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.json(&request)
|
.json(request)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.context("Failed to send request to DeepSeek")?;
|
.context("Failed to send request to DeepSeek")?;
|
||||||
@@ -228,7 +362,11 @@ impl DeepSeekClient {
|
|||||||
let text = response.text().await.unwrap_or_default();
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type);
|
bail!(
|
||||||
|
"DeepSeek API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
bail!("DeepSeek API error: {} - {}", status, text);
|
bail!("DeepSeek API error: {} - {}", status, text);
|
||||||
@@ -239,34 +377,165 @@ impl DeepSeekClient {
|
|||||||
.await
|
.await
|
||||||
.context("Failed to parse DeepSeek response")?;
|
.context("Failed to parse DeepSeek response")?;
|
||||||
|
|
||||||
result.choices
|
result
|
||||||
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| {
|
.map(|c| c.message.content.trim().to_string())
|
||||||
let content = c.message.content.trim().to_string();
|
|
||||||
if content.is_empty() {
|
|
||||||
c.reasoning_content
|
|
||||||
.map(|r| r.trim().to_string())
|
|
||||||
.unwrap_or_default()
|
|
||||||
} else {
|
|
||||||
content
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter(|s| !s.is_empty())
|
.filter(|s| !s.is_empty())
|
||||||
.ok_or_else(|| anyhow::anyhow!(
|
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式请求(思考模式),处理 reasoning_content 和 content
|
||||||
|
async fn streaming_chat_completion(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> Result<String> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
|
.json(request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to send streaming request to DeepSeek")?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
|
bail!(
|
||||||
|
"DeepSeek API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("DeepSeek API error: {} - {}", status, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut content_buffer = String::new();
|
||||||
|
let mut has_reasoning = false;
|
||||||
|
let mut has_content = false;
|
||||||
|
let mut stream_ended = false;
|
||||||
|
|
||||||
|
let thinking_state = self.thinking_state.as_ref();
|
||||||
|
|
||||||
|
let mut byte_stream = response.bytes_stream();
|
||||||
|
let mut line_buffer = String::new();
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||||
|
let chunk_str =
|
||||||
|
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||||
|
|
||||||
|
line_buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
// 处理完整行
|
||||||
|
while let Some(line_end) = line_buffer.find('\n') {
|
||||||
|
let line = line_buffer[..line_end].trim().to_string();
|
||||||
|
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE 格式:data: {...} 或 data: [DONE]
|
||||||
|
if line == "data: [DONE]" {
|
||||||
|
stream_ended = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||||
|
match serde_json::from_str::<StreamChunk>(json_str) {
|
||||||
|
Ok(chunk) => {
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
// 处理 reasoning_content
|
||||||
|
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||||
|
&& !reasoning.is_empty() {
|
||||||
|
if !has_reasoning {
|
||||||
|
has_reasoning = true;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.start_thinking();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// reasoning_content 不对外输出,仅用于内部状态判断
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 content
|
||||||
|
if let Some(ref content) = choice.delta.content
|
||||||
|
&& !content.is_empty() {
|
||||||
|
// reasoning 结束,content 开始出现时移除 thinking 标识
|
||||||
|
if has_reasoning && !has_content
|
||||||
|
&& let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
has_content = true;
|
||||||
|
content_buffer.push_str(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 finish_reason
|
||||||
|
if let Some(ref reason) = choice.finish_reason
|
||||||
|
&& reason == "stop" {
|
||||||
|
stream_ended = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// 忽略无法解析的行(可能是心跳或注释)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream_ended {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保思考状态已结束
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = content_buffer.trim().to_string();
|
||||||
|
|
||||||
|
if result.is_empty() {
|
||||||
|
if has_reasoning && !has_content {
|
||||||
|
bail!(
|
||||||
|
"DeepSeek returned reasoning content but no final answer. \
|
||||||
|
The model may have entered an incomplete thinking state. \
|
||||||
|
Please try again or disable thinking mode."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
bail!(
|
||||||
"No response from DeepSeek. \
|
"No response from DeepSeek. \
|
||||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||||
))
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Available DeepSeek models
|
/// 可用 DeepSeek 模型列表
|
||||||
|
/// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列
|
||||||
pub const DEEPSEEK_MODELS: &[&str] = &[
|
pub const DEEPSEEK_MODELS: &[&str] = &[
|
||||||
|
"deepseek-v4-flash",
|
||||||
|
"deepseek-v4-pro",
|
||||||
|
// 兼容旧版模型 ID(将于 2026-07-24 停用)
|
||||||
"deepseek-chat",
|
"deepseek-chat",
|
||||||
"deepseek-reasoner",
|
"deepseek-reasoner",
|
||||||
];
|
];
|
||||||
|
|
||||||
/// Check if a model name is valid
|
|
||||||
pub fn is_valid_model(model: &str) -> bool {
|
pub fn is_valid_model(model: &str) -> bool {
|
||||||
DEEPSEEK_MODELS.contains(&model)
|
DEEPSEEK_MODELS.contains(&model)
|
||||||
}
|
}
|
||||||
@@ -276,9 +545,76 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_validation() {
|
fn test_model_validation_v4() {
|
||||||
|
assert!(is_valid_model("deepseek-v4-flash"));
|
||||||
|
assert!(is_valid_model("deepseek-v4-pro"));
|
||||||
assert!(is_valid_model("deepseek-chat"));
|
assert!(is_valid_model("deepseek-chat"));
|
||||||
assert!(is_valid_model("deepseek-reasoner"));
|
assert!(is_valid_model("deepseek-reasoner"));
|
||||||
assert!(!is_valid_model("invalid-model"));
|
assert!(!is_valid_model("invalid-model"));
|
||||||
|
assert!(!is_valid_model("deepseek-v3"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_builder_defaults() {
|
||||||
|
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();
|
||||||
|
assert!(!client.thinking_enabled);
|
||||||
|
assert_eq!(client.max_tokens, 500);
|
||||||
|
assert_eq!(client.temperature, 0.7);
|
||||||
|
assert!(client.reasoning_effort.is_none());
|
||||||
|
assert!(client.thinking_state.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_builder_with_thinking() {
|
||||||
|
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash")
|
||||||
|
.unwrap()
|
||||||
|
.with_thinking(true)
|
||||||
|
.with_reasoning_effort(Some("high".to_string()))
|
||||||
|
.with_max_tokens(1000)
|
||||||
|
.with_temperature(0.5);
|
||||||
|
|
||||||
|
assert!(client.thinking_enabled);
|
||||||
|
assert_eq!(client.reasoning_effort, Some("high".to_string()));
|
||||||
|
assert_eq!(client.max_tokens, 1000);
|
||||||
|
assert_eq!(client.temperature, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_config_serialization() {
|
||||||
|
let config = ThinkingConfig {
|
||||||
|
thinking_type: "enabled".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
|
assert_eq!(json, r#"{"type":"enabled"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization_without_reasoning() {
|
||||||
|
let msg = Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello".to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(!json.contains("reasoning_content"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_delta_parsing() {
|
||||||
|
let json = r#"{"content":"Hello","reasoning_content":null}"#;
|
||||||
|
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(delta.content, Some("Hello".to_string()));
|
||||||
|
assert!(delta.reasoning_content.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_delta_reasoning_only() {
|
||||||
|
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||||
|
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(delta.content.is_none());
|
||||||
|
assert_eq!(
|
||||||
|
delta.reasoning_content,
|
||||||
|
Some("Let me think...".to_string())
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
855
src/llm/kimi.rs
855
src/llm/kimi.rs
@@ -1,284 +1,571 @@
|
|||||||
use super::{create_http_client, LlmProvider};
|
use super::thinking::ThinkingStateManager;
|
||||||
use anyhow::{bail, Context, Result};
|
use super::{create_http_client, LlmProvider};
|
||||||
use async_trait::async_trait;
|
use anyhow::{bail, Context, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use async_trait::async_trait;
|
||||||
use std::time::Duration;
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
/// Kimi API client (Moonshot AI)
|
use std::time::Duration;
|
||||||
pub struct KimiClient {
|
|
||||||
base_url: String,
|
/// Kimi API client (Moonshot AI)
|
||||||
api_key: String,
|
pub struct KimiClient {
|
||||||
model: String,
|
base_url: String,
|
||||||
client: reqwest::Client,
|
api_key: String,
|
||||||
thinking_enabled: bool,
|
model: String,
|
||||||
}
|
client: reqwest::Client,
|
||||||
|
thinking_enabled: bool,
|
||||||
#[derive(Debug, Serialize)]
|
max_tokens: u32,
|
||||||
struct ChatCompletionRequest {
|
temperature: f32,
|
||||||
model: String,
|
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||||
messages: Vec<Message>,
|
}
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
max_tokens: Option<u32>,
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
struct ChatCompletionRequest {
|
||||||
temperature: Option<f32>,
|
model: String,
|
||||||
stream: bool,
|
messages: Vec<Message>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
thinking: Option<ThinkingConfig>,
|
max_tokens: Option<u32>,
|
||||||
}
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
temperature: Option<f32>,
|
||||||
#[derive(Debug, Serialize)]
|
stream: bool,
|
||||||
struct ThinkingConfig {
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
#[serde(rename = "type")]
|
thinking: Option<ThinkingConfig>,
|
||||||
thinking_type: String,
|
}
|
||||||
}
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
struct ThinkingConfig {
|
||||||
struct Message {
|
#[serde(rename = "type")]
|
||||||
role: String,
|
thinking_type: String,
|
||||||
content: String,
|
}
|
||||||
}
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[derive(Debug, Deserialize)]
|
struct Message {
|
||||||
struct ChatCompletionResponse {
|
role: String,
|
||||||
choices: Vec<Choice>,
|
content: String,
|
||||||
}
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
#[derive(Debug, Deserialize)]
|
}
|
||||||
struct Choice {
|
|
||||||
message: Message,
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(default)]
|
struct ChatCompletionResponse {
|
||||||
reasoning_content: Option<String>,
|
choices: Vec<Choice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ErrorResponse {
|
struct Choice {
|
||||||
error: ApiError,
|
message: Message,
|
||||||
}
|
#[serde(default)]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
#[derive(Debug, Deserialize)]
|
}
|
||||||
struct ApiError {
|
|
||||||
message: String,
|
// --- Streaming response structures ---
|
||||||
#[serde(rename = "type")]
|
|
||||||
error_type: String,
|
#[derive(Debug, Deserialize)]
|
||||||
}
|
struct StreamChunk {
|
||||||
|
choices: Vec<StreamChoice>,
|
||||||
impl KimiClient {
|
}
|
||||||
/// Create new Kimi client
|
|
||||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
#[derive(Debug, Deserialize)]
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
struct StreamChoice {
|
||||||
|
delta: StreamDelta,
|
||||||
Ok(Self {
|
#[serde(default)]
|
||||||
base_url: "https://api.moonshot.cn/v1".to_string(),
|
finish_reason: Option<String>,
|
||||||
api_key: api_key.to_string(),
|
index: Option<u32>,
|
||||||
model: model.to_string(),
|
}
|
||||||
client,
|
|
||||||
thinking_enabled: false,
|
#[derive(Debug, Deserialize, Default)]
|
||||||
})
|
struct StreamDelta {
|
||||||
}
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
/// Create with custom base URL
|
#[serde(default)]
|
||||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
reasoning_content: Option<String>,
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
}
|
||||||
|
|
||||||
Ok(Self {
|
#[derive(Debug, Deserialize)]
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
struct ErrorResponse {
|
||||||
api_key: api_key.to_string(),
|
error: ApiError,
|
||||||
model: model.to_string(),
|
}
|
||||||
client,
|
|
||||||
thinking_enabled: false,
|
#[derive(Debug, Deserialize)]
|
||||||
})
|
struct ApiError {
|
||||||
}
|
message: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
/// Set timeout
|
error_type: String,
|
||||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
}
|
||||||
self.client = create_http_client(timeout)?;
|
|
||||||
Ok(self)
|
impl KimiClient {
|
||||||
}
|
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||||
|
let client = create_http_client(Duration::from_secs(300))?;
|
||||||
/// Enable or disable thinking mode
|
|
||||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
Ok(Self {
|
||||||
self.thinking_enabled = enabled;
|
base_url: "https://api.moonshot.cn/v1".to_string(),
|
||||||
self
|
api_key: api_key.to_string(),
|
||||||
}
|
model: model.to_string(),
|
||||||
|
client,
|
||||||
/// List available models
|
thinking_enabled: false,
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
max_tokens: 500,
|
||||||
let url = format!("{}/models", self.base_url);
|
temperature: 1.0,
|
||||||
|
thinking_state: None,
|
||||||
let response = self.client
|
})
|
||||||
.get(&url)
|
}
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.send()
|
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||||
.await
|
let client = create_http_client(Duration::from_secs(300))?;
|
||||||
.context("Failed to list Kimi models")?;
|
|
||||||
|
Ok(Self {
|
||||||
if !response.status().is_success() {
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
let status = response.status();
|
api_key: api_key.to_string(),
|
||||||
let text = response.text().await.unwrap_or_default();
|
model: model.to_string(),
|
||||||
bail!("Kimi API error: {} - {}", status, text);
|
client,
|
||||||
}
|
thinking_enabled: false,
|
||||||
|
max_tokens: 500,
|
||||||
#[derive(Deserialize)]
|
temperature: 1.0,
|
||||||
struct ModelsResponse {
|
thinking_state: None,
|
||||||
data: Vec<Model>,
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||||
struct Model {
|
self.client = create_http_client(timeout)?;
|
||||||
id: String,
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ModelsResponse = response
|
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||||
.json()
|
self.thinking_enabled = enabled;
|
||||||
.await
|
self
|
||||||
.context("Failed to parse Kimi response")?;
|
}
|
||||||
|
|
||||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
}
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
/// Validate API key
|
}
|
||||||
pub async fn validate_key(&self) -> Result<bool> {
|
|
||||||
match self.list_models().await {
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
Ok(_) => Ok(true),
|
self.temperature = temperature;
|
||||||
Err(e) => {
|
self
|
||||||
let err_str = e.to_string();
|
}
|
||||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
|
||||||
Ok(false)
|
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||||
} else {
|
self.thinking_state = Some(state);
|
||||||
Err(e)
|
self
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
}
|
let url = format!("{}/models", self.base_url);
|
||||||
}
|
|
||||||
|
let response = self
|
||||||
#[async_trait]
|
.client
|
||||||
impl LlmProvider for KimiClient {
|
.get(&url)
|
||||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
let messages = vec![
|
.send()
|
||||||
Message {
|
.await
|
||||||
role: "user".to_string(),
|
.context("Failed to list Kimi models")?;
|
||||||
content: prompt.to_string(),
|
|
||||||
},
|
if !response.status().is_success() {
|
||||||
];
|
let status = response.status();
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
self.chat_completion(messages).await
|
bail!("Kimi API error: {} - {}", status, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
#[derive(Deserialize)]
|
||||||
let mut messages = vec![];
|
struct ModelsResponse {
|
||||||
|
data: Vec<ModelId>,
|
||||||
if !system.is_empty() {
|
}
|
||||||
messages.push(Message {
|
|
||||||
role: "system".to_string(),
|
#[derive(Deserialize)]
|
||||||
content: system.to_string(),
|
struct ModelId {
|
||||||
});
|
id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.push(Message {
|
let result: ModelsResponse = response
|
||||||
role: "user".to_string(),
|
.json()
|
||||||
content: user.to_string(),
|
.await
|
||||||
});
|
.context("Failed to parse Kimi response")?;
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_available(&self) -> bool {
|
pub async fn validate_key(&self) -> Result<bool> {
|
||||||
self.validate_key().await.unwrap_or(false)
|
match self.list_models().await {
|
||||||
}
|
Ok(_) => Ok(true),
|
||||||
|
Err(e) => {
|
||||||
fn name(&self) -> &str {
|
let err_str = e.to_string();
|
||||||
"kimi"
|
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||||
}
|
Ok(false)
|
||||||
}
|
} else {
|
||||||
|
Err(e)
|
||||||
impl KimiClient {
|
}
|
||||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
}
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
}
|
||||||
|
}
|
||||||
let thinking = if self.thinking_enabled {
|
}
|
||||||
Some(ThinkingConfig {
|
|
||||||
thinking_type: "enabled".to_string(),
|
#[async_trait]
|
||||||
})
|
impl LlmProvider for KimiClient {
|
||||||
} else {
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||||
None
|
let messages = vec![Message {
|
||||||
};
|
role: "user".to_string(),
|
||||||
|
content: prompt.to_string(),
|
||||||
let request = ChatCompletionRequest {
|
reasoning_content: None,
|
||||||
model: self.model.clone(),
|
}];
|
||||||
messages,
|
|
||||||
max_tokens: Some(500),
|
self.chat_completion_with_retry(messages).await
|
||||||
temperature: Some(1.0),
|
}
|
||||||
stream: false,
|
|
||||||
thinking,
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||||
};
|
let mut messages = vec![];
|
||||||
|
|
||||||
let response = self.client
|
if !system.is_empty() {
|
||||||
.post(&url)
|
messages.push(Message {
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
role: "system".to_string(),
|
||||||
.header("Content-Type", "application/json")
|
content: system.to_string(),
|
||||||
.json(&request)
|
reasoning_content: None,
|
||||||
.send()
|
});
|
||||||
.await
|
}
|
||||||
.context("Failed to send request to Kimi")?;
|
|
||||||
|
messages.push(Message {
|
||||||
let status = response.status();
|
role: "user".to_string(),
|
||||||
|
content: user.to_string(),
|
||||||
if !status.is_success() {
|
reasoning_content: None,
|
||||||
let text = response.text().await.unwrap_or_default();
|
});
|
||||||
|
|
||||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
self.chat_completion_with_retry(messages).await
|
||||||
bail!("Kimi API error: {} ({})", error.error.message, error.error.error_type);
|
}
|
||||||
}
|
|
||||||
|
async fn is_available(&self) -> bool {
|
||||||
bail!("Kimi API error: {} - {}", status, text);
|
self.validate_key().await.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ChatCompletionResponse = response
|
fn name(&self) -> &str {
|
||||||
.json()
|
"kimi"
|
||||||
.await
|
}
|
||||||
.context("Failed to parse Kimi response")?;
|
}
|
||||||
|
|
||||||
result.choices
|
impl KimiClient {
|
||||||
.into_iter()
|
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
.next()
|
let mut last_error = None;
|
||||||
.map(|c| {
|
|
||||||
let content = c.message.content.trim().to_string();
|
for attempt in 1..=3 {
|
||||||
if content.is_empty() {
|
match self.chat_completion(messages.clone()).await {
|
||||||
c.reasoning_content
|
Ok(result) => return Ok(result),
|
||||||
.map(|r| r.trim().to_string())
|
Err(e) => {
|
||||||
.unwrap_or_default()
|
let err_msg = e.to_string();
|
||||||
} else {
|
let is_retryable = err_msg.contains("timeout")
|
||||||
content
|
|| err_msg.contains("connection")
|
||||||
}
|
|| err_msg.contains("temporary")
|
||||||
})
|
|| err_msg.contains("5")
|
||||||
.filter(|s| !s.is_empty())
|
&& (err_msg.contains("500")
|
||||||
.ok_or_else(|| anyhow::anyhow!(
|
|| err_msg.contains("502")
|
||||||
"No response from Kimi. \
|
|| err_msg.contains("503")
|
||||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
|| err_msg.contains("504"));
|
||||||
))
|
|
||||||
}
|
if !is_retryable || attempt == 3 {
|
||||||
}
|
last_error = Some(e);
|
||||||
|
break;
|
||||||
/// Available Kimi models
|
}
|
||||||
pub const KIMI_MODELS: &[&str] = &[
|
|
||||||
"moonshot-v1-8k",
|
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||||
"moonshot-v1-32k",
|
}
|
||||||
"moonshot-v1-128k",
|
}
|
||||||
];
|
}
|
||||||
|
|
||||||
/// Check if a model name is valid
|
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||||
pub fn is_valid_model(model: &str) -> bool {
|
}
|
||||||
KIMI_MODELS.contains(&model)
|
|
||||||
}
|
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
let thinking = if self.thinking_enabled {
|
||||||
use super::*;
|
Some(ThinkingConfig {
|
||||||
|
thinking_type: "enabled".to_string(),
|
||||||
#[test]
|
})
|
||||||
fn test_model_validation() {
|
} else {
|
||||||
assert!(is_valid_model("moonshot-v1-8k"));
|
None
|
||||||
assert!(!is_valid_model("invalid-model"));
|
};
|
||||||
}
|
|
||||||
}
|
// 对于 kimi-k2.6 等支持思考模式的模型,使用默认 temperature 即可
|
||||||
|
// 思考模式下不显式指定 temperature
|
||||||
|
let temperature = if self.thinking_enabled {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.temperature)
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages: messages.clone(),
|
||||||
|
max_tokens: Some(self.max_tokens),
|
||||||
|
temperature,
|
||||||
|
stream: self.thinking_enabled,
|
||||||
|
thinking,
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.thinking_enabled {
|
||||||
|
self.streaming_chat_completion(&url, &request).await
|
||||||
|
} else {
|
||||||
|
self.non_streaming_chat_completion(&url, &request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 非流式请求(非思考模式)
|
||||||
|
async fn non_streaming_chat_completion(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> Result<String> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to send request to Kimi")?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
|
bail!(
|
||||||
|
"Kimi API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Kimi API error: {} - {}", status, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: ChatCompletionResponse = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.context("Failed to parse Kimi response")?;
|
||||||
|
|
||||||
|
result
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message.content.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 流式请求(思考模式),处理 reasoning_content 和 content
|
||||||
|
async fn streaming_chat_completion(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
request: &ChatCompletionRequest,
|
||||||
|
) -> Result<String> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
|
.json(request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to send streaming request to Kimi")?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
|
bail!(
|
||||||
|
"Kimi API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Kimi API error: {} - {}", status, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut content_buffer = String::new();
|
||||||
|
let mut has_reasoning = false;
|
||||||
|
let mut has_content = false;
|
||||||
|
let mut stream_ended = false;
|
||||||
|
|
||||||
|
let thinking_state = self.thinking_state.as_ref();
|
||||||
|
|
||||||
|
let mut byte_stream = response.bytes_stream();
|
||||||
|
let mut line_buffer = String::new();
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||||
|
let chunk_str =
|
||||||
|
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||||
|
|
||||||
|
line_buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
while let Some(line_end) = line_buffer.find('\n') {
|
||||||
|
let line = line_buffer[..line_end].trim().to_string();
|
||||||
|
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "data: [DONE]" {
|
||||||
|
stream_ended = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||||
|
match serde_json::from_str::<StreamChunk>(json_str) {
|
||||||
|
Ok(chunk) => {
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||||
|
&& !reasoning.is_empty() {
|
||||||
|
if !has_reasoning {
|
||||||
|
has_reasoning = true;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.start_thinking();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref content) = choice.delta.content
|
||||||
|
&& !content.is_empty() {
|
||||||
|
if has_reasoning && !has_content
|
||||||
|
&& let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
has_content = true;
|
||||||
|
content_buffer.push_str(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref reason) = choice.finish_reason
|
||||||
|
&& reason == "stop" {
|
||||||
|
stream_ended = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// 忽略无法解析的行
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream_ended {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保思考状态已结束
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = content_buffer.trim().to_string();
|
||||||
|
|
||||||
|
if result.is_empty() {
|
||||||
|
if has_reasoning && !has_content {
|
||||||
|
bail!(
|
||||||
|
"Kimi returned reasoning content but no final answer. \
|
||||||
|
The model may have entered an incomplete thinking state. \
|
||||||
|
Please try again or disable thinking mode."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
bail!(
|
||||||
|
"No response from Kimi. \
|
||||||
|
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 可用 Kimi 模型列表
|
||||||
|
pub const KIMI_MODELS: &[&str] = &[
|
||||||
|
// K2 系列(推荐)
|
||||||
|
"kimi-k2.6",
|
||||||
|
"kimi-k2.5",
|
||||||
|
"kimi-k2-thinking",
|
||||||
|
"kimi-k2-thinking-turbo",
|
||||||
|
"kimi-k2-instruct",
|
||||||
|
"kimi-k2-instruct-0905",
|
||||||
|
// 兼容旧版模型 ID
|
||||||
|
"moonshot-v1-8k",
|
||||||
|
"moonshot-v1-32k",
|
||||||
|
"moonshot-v1-128k",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn is_valid_model(model: &str) -> bool {
|
||||||
|
KIMI_MODELS.contains(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_validation_k2() {
|
||||||
|
assert!(is_valid_model("kimi-k2.6"));
|
||||||
|
assert!(is_valid_model("kimi-k2.5"));
|
||||||
|
assert!(is_valid_model("kimi-k2-thinking"));
|
||||||
|
assert!(is_valid_model("kimi-k2-thinking-turbo"));
|
||||||
|
assert!(is_valid_model("moonshot-v1-8k"));
|
||||||
|
assert!(is_valid_model("moonshot-v1-32k"));
|
||||||
|
assert!(is_valid_model("moonshot-v1-128k"));
|
||||||
|
assert!(!is_valid_model("invalid-model"));
|
||||||
|
assert!(!is_valid_model("kimi-k1.5"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_builder_defaults() {
|
||||||
|
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
|
||||||
|
assert!(!client.thinking_enabled);
|
||||||
|
assert_eq!(client.max_tokens, 500);
|
||||||
|
assert_eq!(client.temperature, 1.0);
|
||||||
|
assert!(client.thinking_state.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_builder_with_thinking() {
|
||||||
|
let client = KimiClient::new("test-key", "kimi-k2.6")
|
||||||
|
.unwrap()
|
||||||
|
.with_thinking(true)
|
||||||
|
.with_max_tokens(1000)
|
||||||
|
.with_temperature(0.5);
|
||||||
|
|
||||||
|
assert!(client.thinking_enabled);
|
||||||
|
assert_eq!(client.max_tokens, 1000);
|
||||||
|
assert_eq!(client.temperature, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_config_serialization() {
|
||||||
|
let config = ThinkingConfig {
|
||||||
|
thinking_type: "enabled".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
|
assert_eq!(json, r#"{"type":"enabled"}"#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_new_defaults() {
|
||||||
|
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
|
||||||
|
assert_eq!(client.name(), "kimi");
|
||||||
|
assert!(!client.thinking_enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_message_serialization() {
|
||||||
|
let msg = Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hello".to_string(),
|
||||||
|
reasoning_content: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&msg).unwrap();
|
||||||
|
assert!(!json.contains("reasoning_content"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ pub mod anthropic;
|
|||||||
pub mod kimi;
|
pub mod kimi;
|
||||||
pub mod deepseek;
|
pub mod deepseek;
|
||||||
pub mod openrouter;
|
pub mod openrouter;
|
||||||
|
pub mod thinking;
|
||||||
|
|
||||||
pub use ollama::OllamaClient;
|
pub use ollama::OllamaClient;
|
||||||
pub use openai::OpenAiClient;
|
pub use openai::OpenAiClient;
|
||||||
@@ -61,48 +62,114 @@ impl Default for LlmClientConfig {
|
|||||||
impl LlmClient {
|
impl LlmClient {
|
||||||
/// Create LLM client from configuration manager
|
/// Create LLM client from configuration manager
|
||||||
pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result<Self> {
|
pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result<Self> {
|
||||||
|
Self::from_config_with_think(manager, manager.config().llm.thinking_enabled).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create LLM client from configuration with explicit thinking override
|
||||||
|
pub async fn from_config_with_think(
|
||||||
|
manager: &crate::config::manager::ConfigManager,
|
||||||
|
thinking_enabled: bool,
|
||||||
|
) -> Result<Self> {
|
||||||
let config = manager.config();
|
let config = manager.config();
|
||||||
let client_config = LlmClientConfig {
|
let client_config = LlmClientConfig {
|
||||||
max_tokens: config.llm.max_tokens,
|
max_tokens: config.llm.max_tokens,
|
||||||
temperature: config.llm.temperature,
|
temperature: config.llm.temperature,
|
||||||
timeout: Duration::from_secs(config.llm.timeout),
|
timeout: Duration::from_secs(config.llm.timeout),
|
||||||
thinking_enabled: config.llm.thinking_enabled,
|
thinking_enabled,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider = config.llm.provider.as_str();
|
let provider = config.llm.provider.as_str();
|
||||||
let model = config.llm.model.as_str();
|
let model = config.llm.model.as_str();
|
||||||
let base_url = manager.llm_base_url();
|
let base_url = manager.llm_base_url();
|
||||||
let api_key = manager.get_api_key();
|
let api_key = manager.get_api_key();
|
||||||
let thinking_enabled = config.llm.thinking_enabled;
|
|
||||||
|
|
||||||
let provider: Box<dyn LlmProvider> = match provider {
|
let provider: Box<dyn LlmProvider> = match provider {
|
||||||
"ollama" => {
|
"ollama" => {
|
||||||
Box::new(OllamaClient::new(&base_url, model))
|
Box::new(OllamaClient::new(&base_url, model)
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature))
|
||||||
}
|
}
|
||||||
"openai" => {
|
"openai" => {
|
||||||
let key = api_key.as_ref()
|
let key = api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
|
||||||
Box::new(OpenAiClient::new(&base_url, key, model)?)
|
let thinking_state = if thinking_enabled {
|
||||||
|
Some(thinking::create_console_thinking_state())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut client = OpenAiClient::new(&base_url, key, model)?
|
||||||
|
.with_thinking(thinking_enabled)
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature)
|
||||||
|
.with_timeout(client_config.timeout)?;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
client = client.with_thinking_state(state);
|
||||||
|
}
|
||||||
|
Box::new(client)
|
||||||
}
|
}
|
||||||
"anthropic" => {
|
"anthropic" => {
|
||||||
let key = api_key.as_ref()
|
let key = api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
|
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
|
||||||
Box::new(AnthropicClient::new(key, model)?)
|
let thinking_state = if thinking_enabled {
|
||||||
|
Some(thinking::create_console_thinking_state())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let budget = config.llm.thinking_budget_tokens.unwrap_or(1024);
|
||||||
|
let mut client = AnthropicClient::new(key, model)?
|
||||||
|
.with_thinking(thinking_enabled)
|
||||||
|
.with_thinking_budget_tokens(budget)
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature)
|
||||||
|
.with_timeout(client_config.timeout)?;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
client = client.with_thinking_state(state);
|
||||||
|
}
|
||||||
|
Box::new(client)
|
||||||
}
|
}
|
||||||
"kimi" => {
|
"kimi" => {
|
||||||
let key = api_key.as_ref()
|
let key = api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
|
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
|
||||||
Box::new(KimiClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled))
|
let thinking_state = if thinking_enabled {
|
||||||
|
Some(thinking::create_console_thinking_state())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut client = KimiClient::with_base_url(key, model, &base_url)?
|
||||||
|
.with_thinking(thinking_enabled)
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature)
|
||||||
|
.with_timeout(client_config.timeout)?;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
client = client.with_thinking_state(state);
|
||||||
|
}
|
||||||
|
Box::new(client)
|
||||||
}
|
}
|
||||||
"deepseek" => {
|
"deepseek" => {
|
||||||
let key = api_key.as_ref()
|
let key = api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
|
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
|
||||||
Box::new(DeepSeekClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled))
|
let thinking_state = if thinking_enabled {
|
||||||
|
Some(thinking::create_console_thinking_state())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut client = DeepSeekClient::with_base_url(key, model, &base_url)?
|
||||||
|
.with_thinking(thinking_enabled)
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature)
|
||||||
|
.with_timeout(client_config.timeout)?;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
client = client.with_thinking_state(state);
|
||||||
|
}
|
||||||
|
Box::new(client)
|
||||||
}
|
}
|
||||||
"openrouter" => {
|
"openrouter" => {
|
||||||
let key = api_key.as_ref()
|
let key = api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
|
||||||
Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?)
|
Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?
|
||||||
|
.with_max_tokens(client_config.max_tokens)
|
||||||
|
.with_temperature(client_config.temperature)
|
||||||
|
.with_timeout(client_config.timeout)?)
|
||||||
}
|
}
|
||||||
_ => bail!("Unknown LLM provider: {}", provider),
|
_ => bail!("Unknown LLM provider: {}", provider),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ pub struct OllamaClient {
|
|||||||
base_url: String,
|
base_url: String,
|
||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
top_p: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -49,11 +52,14 @@ impl OllamaClient {
|
|||||||
pub fn new(base_url: &str, model: &str) -> Self {
|
pub fn new(base_url: &str, model: &str) -> Self {
|
||||||
let client = create_http_client(Duration::from_secs(120))
|
let client = create_http_client(Duration::from_secs(120))
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,6 +70,21 @@ impl OllamaClient {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = temperature;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// List available models
|
/// List available models
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
let url = format!("{}/api/tags", self.base_url);
|
let url = format!("{}/api/tags", self.base_url);
|
||||||
@@ -143,8 +164,8 @@ impl LlmProvider for OllamaClient {
|
|||||||
system,
|
system,
|
||||||
stream: false,
|
stream: false,
|
||||||
options: GenerationOptions {
|
options: GenerationOptions {
|
||||||
temperature: Some(0.7),
|
temperature: Some(self.temperature),
|
||||||
num_predict: Some(500),
|
num_predict: Some(self.max_tokens),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,23 @@
|
|||||||
|
use super::thinking::ThinkingStateManager;
|
||||||
use super::{create_http_client, LlmProvider};
|
use super::{create_http_client, LlmProvider};
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
/// OpenAI API client
|
/// OpenAI API client with o-series reasoning support
|
||||||
pub struct OpenAiClient {
|
pub struct OpenAiClient {
|
||||||
base_url: String,
|
base_url: String,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
thinking_enabled: bool,
|
||||||
|
reasoning_effort: Option<String>,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
top_p: Option<f32>,
|
||||||
|
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -20,10 +28,14 @@ struct ChatCompletionRequest {
|
|||||||
max_tokens: Option<u32>,
|
max_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
temperature: Option<f32>,
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_effort: Option<String>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
struct Message {
|
struct Message {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
content: String,
|
||||||
@@ -39,6 +51,28 @@ struct Choice {
|
|||||||
message: Message,
|
message: Message,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Streaming response structures ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChunk {
|
||||||
|
choices: Vec<StreamChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChoice {
|
||||||
|
delta: StreamDelta,
|
||||||
|
#[serde(default)]
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Default)]
|
||||||
|
struct StreamDelta {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ErrorResponse {
|
struct ErrorResponse {
|
||||||
error: ApiError,
|
error: ApiError,
|
||||||
@@ -55,57 +89,91 @@ impl OpenAiClient {
|
|||||||
/// Create new OpenAI client
|
/// Create new OpenAI client
|
||||||
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(60))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
|
thinking_enabled: false,
|
||||||
|
reasoning_effort: None,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
|
thinking_state: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set timeout
|
|
||||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||||
self.client = create_http_client(timeout)?;
|
self.client = create_http_client(timeout)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List available models
|
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||||
|
self.thinking_enabled = enabled;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
|
||||||
|
self.reasoning_effort = effort;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = temperature;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||||
|
self.thinking_state = Some(state);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
let url = format!("{}/models", self.base_url);
|
let url = format!("{}/models", self.base_url);
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.context("Failed to list OpenAI models")?;
|
.context("Failed to list OpenAI models")?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let text = response.text().await.unwrap_or_default();
|
let text = response.text().await.unwrap_or_default();
|
||||||
bail!("OpenAI API error: {} - {}", status, text);
|
bail!("OpenAI API error: {} - {}", status, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ModelsResponse {
|
struct ModelsResponse {
|
||||||
data: Vec<Model>,
|
data: Vec<Model>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Model {
|
struct Model {
|
||||||
id: String,
|
id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ModelsResponse = response
|
let result: ModelsResponse = response
|
||||||
.json()
|
.json()
|
||||||
.await
|
.await
|
||||||
.context("Failed to parse OpenAI response")?;
|
.context("Failed to parse OpenAI response")?;
|
||||||
|
|
||||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate API key
|
|
||||||
pub async fn validate_key(&self) -> Result<bool> {
|
pub async fn validate_key(&self) -> Result<bool> {
|
||||||
match self.list_models().await {
|
match self.list_models().await {
|
||||||
Ok(_) => Ok(true),
|
Ok(_) => Ok(true),
|
||||||
@@ -124,32 +192,30 @@ impl OpenAiClient {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LlmProvider for OpenAiClient {
|
impl LlmProvider for OpenAiClient {
|
||||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||||
let messages = vec![
|
let messages = vec![Message {
|
||||||
Message {
|
role: "user".to_string(),
|
||||||
role: "user".to_string(),
|
content: prompt.to_string(),
|
||||||
content: prompt.to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
self.chat_completion_with_retry(messages).await
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||||
let mut messages = vec![];
|
let mut messages = vec![];
|
||||||
|
|
||||||
if !system.is_empty() {
|
if !system.is_empty() {
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: system.to_string(),
|
content: system.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: user.to_string(),
|
content: user.to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
self.chat_completion_with_retry(messages).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_available(&self) -> bool {
|
async fn is_available(&self) -> bool {
|
||||||
@@ -162,18 +228,59 @@ impl LlmProvider for OpenAiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiClient {
|
impl OpenAiClient {
|
||||||
|
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
let mut last_error = None;
|
||||||
|
|
||||||
|
for attempt in 1..=3 {
|
||||||
|
match self.chat_completion(messages.clone()).await {
|
||||||
|
Ok(result) => return Ok(result),
|
||||||
|
Err(e) => {
|
||||||
|
let err_msg = e.to_string();
|
||||||
|
let is_retryable = err_msg.contains("timeout")
|
||||||
|
|| err_msg.contains("connection")
|
||||||
|
|| err_msg.contains("temporary")
|
||||||
|
|| err_msg.contains("5")
|
||||||
|
&& (err_msg.contains("500")
|
||||||
|
|| err_msg.contains("502")
|
||||||
|
|| err_msg.contains("503")
|
||||||
|
|| err_msg.contains("504"));
|
||||||
|
|
||||||
|
if !is_retryable || attempt == 3 {
|
||||||
|
last_error = Some(e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||||
|
}
|
||||||
|
|
||||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
if self.thinking_enabled {
|
||||||
|
self.streaming_chat_completion(messages).await
|
||||||
|
} else {
|
||||||
|
self.non_streaming_chat_completion(messages).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn non_streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages,
|
messages,
|
||||||
max_tokens: Some(500),
|
max_tokens: Some(self.max_tokens),
|
||||||
temperature: Some(0.7),
|
temperature: Some(self.temperature),
|
||||||
|
top_p: self.top_p,
|
||||||
|
reasoning_effort: None,
|
||||||
stream: false,
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
@@ -181,31 +288,165 @@ impl OpenAiClient {
|
|||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.context("Failed to send request to OpenAI")?;
|
.context("Failed to send request to OpenAI")?;
|
||||||
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let text = response.text().await.unwrap_or_default();
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
// Try to parse error
|
|
||||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type);
|
bail!(
|
||||||
|
"OpenAI API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
bail!("OpenAI API error: {} - {}", status, text);
|
bail!("OpenAI API error: {} - {}", status, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ChatCompletionResponse = response
|
let result: ChatCompletionResponse = response
|
||||||
.json()
|
.json()
|
||||||
.await
|
.await
|
||||||
.context("Failed to parse OpenAI response")?;
|
.context("Failed to parse OpenAI response")?;
|
||||||
|
|
||||||
result.choices
|
result
|
||||||
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| c.message.content.trim().to_string())
|
.map(|c| c.message.content.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Streaming request for reasoning mode, filters reasoning_content from output
|
||||||
|
async fn streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
|
// For reasoning/thinking mode, omit temperature and top_p
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
model: self.model.clone(),
|
||||||
|
messages,
|
||||||
|
max_tokens: Some(self.max_tokens),
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
reasoning_effort: self.reasoning_effort.clone(),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to send streaming request to OpenAI")?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||||
|
bail!(
|
||||||
|
"OpenAI API error: {} ({})",
|
||||||
|
error.error.message,
|
||||||
|
error.error.error_type
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("OpenAI API error: {} - {}", status, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut content_buffer = String::new();
|
||||||
|
let mut has_reasoning = false;
|
||||||
|
let mut has_content = false;
|
||||||
|
|
||||||
|
let thinking_state = self.thinking_state.as_ref();
|
||||||
|
|
||||||
|
let mut byte_stream = response.bytes_stream();
|
||||||
|
let mut line_buffer = String::new();
|
||||||
|
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
|
||||||
|
while let Some(chunk) = byte_stream.next().await {
|
||||||
|
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||||
|
let chunk_str =
|
||||||
|
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||||
|
|
||||||
|
line_buffer.push_str(&chunk_str);
|
||||||
|
|
||||||
|
while let Some(line_end) = line_buffer.find('\n') {
|
||||||
|
let line = line_buffer[..line_end].trim().to_string();
|
||||||
|
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||||
|
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "data: [DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||||
|
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(json_str) {
|
||||||
|
for choice in &chunk.choices {
|
||||||
|
// Handle reasoning_content (o-series)
|
||||||
|
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||||
|
&& !reasoning.is_empty()
|
||||||
|
{
|
||||||
|
if !has_reasoning {
|
||||||
|
has_reasoning = true;
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.start_thinking();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle content
|
||||||
|
if let Some(ref content) = choice.delta.content
|
||||||
|
&& !content.is_empty()
|
||||||
|
{
|
||||||
|
if has_reasoning && !has_content
|
||||||
|
&& let Some(state) = thinking_state
|
||||||
|
{
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
has_content = true;
|
||||||
|
content_buffer.push_str(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(state) = thinking_state {
|
||||||
|
state.end_thinking();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = content_buffer.trim().to_string();
|
||||||
|
|
||||||
|
if result.is_empty() {
|
||||||
|
if has_reasoning && !has_content {
|
||||||
|
bail!(
|
||||||
|
"OpenAI returned reasoning content but no final answer. \
|
||||||
|
The model may have entered an incomplete reasoning state. \
|
||||||
|
Please try again or disable thinking mode."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
bail!(
|
||||||
|
"No response from OpenAI. \
|
||||||
|
If thinking mode is enabled, try disabling it or ensure the model supports reasoning."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
|
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
|
||||||
@@ -215,10 +456,15 @@ pub struct AzureOpenAiClient {
|
|||||||
deployment: String,
|
deployment: String,
|
||||||
api_version: String,
|
api_version: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
thinking_enabled: bool,
|
||||||
|
reasoning_effort: Option<String>,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
top_p: Option<f32>,
|
||||||
|
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AzureOpenAiClient {
|
impl AzureOpenAiClient {
|
||||||
/// Create new Azure OpenAI client
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
endpoint: &str,
|
endpoint: &str,
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
@@ -226,13 +472,19 @@ impl AzureOpenAiClient {
|
|||||||
api_version: &str,
|
api_version: &str,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(60))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
endpoint: endpoint.trim_end_matches('/').to_string(),
|
endpoint: endpoint.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
deployment: deployment.to_string(),
|
deployment: deployment.to_string(),
|
||||||
api_version: api_version.to_string(),
|
api_version: api_version.to_string(),
|
||||||
client,
|
client,
|
||||||
|
thinking_enabled: false,
|
||||||
|
reasoning_effort: None,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
|
thinking_state: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,16 +493,19 @@ impl AzureOpenAiClient {
|
|||||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||||
self.endpoint, self.deployment, self.api_version
|
self.endpoint, self.deployment, self.api_version
|
||||||
);
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.deployment.clone(),
|
model: self.deployment.clone(),
|
||||||
messages,
|
messages,
|
||||||
max_tokens: Some(500),
|
max_tokens: Some(self.max_tokens),
|
||||||
temperature: Some(0.7),
|
temperature: Some(self.temperature),
|
||||||
|
top_p: self.top_p,
|
||||||
|
reasoning_effort: self.reasoning_effort.clone(),
|
||||||
stream: false,
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self.client
|
let response = self
|
||||||
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("api-key", &self.api_key)
|
.header("api-key", &self.api_key)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
@@ -258,22 +513,24 @@ impl AzureOpenAiClient {
|
|||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.context("Failed to send request to Azure OpenAI")?;
|
.context("Failed to send request to Azure OpenAI")?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let text = response.text().await.unwrap_or_default();
|
let text = response.text().await.unwrap_or_default();
|
||||||
bail!("Azure OpenAI API error: {} - {}", status, text);
|
bail!("Azure OpenAI API error: {} - {}", status, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
let result: ChatCompletionResponse = response
|
let result: ChatCompletionResponse = response
|
||||||
.json()
|
.json()
|
||||||
.await
|
.await
|
||||||
.context("Failed to parse Azure OpenAI response")?;
|
.context("Failed to parse Azure OpenAI response")?;
|
||||||
|
|
||||||
result.choices
|
result
|
||||||
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| c.message.content.trim().to_string())
|
.map(|c| c.message.content.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
|
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -281,41 +538,38 @@ impl AzureOpenAiClient {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LlmProvider for AzureOpenAiClient {
|
impl LlmProvider for AzureOpenAiClient {
|
||||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||||
let messages = vec![
|
let messages = vec![Message {
|
||||||
Message {
|
role: "user".to_string(),
|
||||||
role: "user".to_string(),
|
content: prompt.to_string(),
|
||||||
content: prompt.to_string(),
|
}];
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
self.chat_completion(messages).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||||
let mut messages = vec![];
|
let mut messages = vec![];
|
||||||
|
|
||||||
if !system.is_empty() {
|
if !system.is_empty() {
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: system.to_string(),
|
content: system.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: user.to_string(),
|
content: user.to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
self.chat_completion(messages).await
|
self.chat_completion(messages).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_available(&self) -> bool {
|
async fn is_available(&self) -> bool {
|
||||||
// Simple check - try to make a minimal request
|
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||||
self.endpoint, self.deployment, self.api_version
|
self.endpoint, self.deployment, self.api_version
|
||||||
);
|
);
|
||||||
|
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.deployment.clone(),
|
model: self.deployment.clone(),
|
||||||
messages: vec![Message {
|
messages: vec![Message {
|
||||||
@@ -324,10 +578,13 @@ impl LlmProvider for AzureOpenAiClient {
|
|||||||
}],
|
}],
|
||||||
max_tokens: Some(5),
|
max_tokens: Some(5),
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
|
top_p: None,
|
||||||
|
reasoning_effort: None,
|
||||||
stream: false,
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
match self.client
|
match self
|
||||||
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("api-key", &self.api_key)
|
.header("api-key", &self.api_key)
|
||||||
.json(&request)
|
.json(&request)
|
||||||
@@ -343,3 +600,59 @@ impl LlmProvider for AzureOpenAiClient {
|
|||||||
"azure-openai"
|
"azure-openai"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Available OpenAI models (including o-series with reasoning)
|
||||||
|
pub const OPENAI_MODELS: &[&str] = &[
|
||||||
|
"o4-mini",
|
||||||
|
"o3",
|
||||||
|
"o3-mini",
|
||||||
|
"o1",
|
||||||
|
"o1-mini",
|
||||||
|
"o1-pro",
|
||||||
|
"gpt-4.1",
|
||||||
|
"gpt-4.1-mini",
|
||||||
|
"gpt-4.1-nano",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
"gpt-4-turbo",
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn is_valid_model(model: &str) -> bool {
|
||||||
|
OPENAI_MODELS.contains(&model)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_validation_o_series() {
|
||||||
|
assert!(is_valid_model("o4-mini"));
|
||||||
|
assert!(is_valid_model("o3"));
|
||||||
|
assert!(is_valid_model("o1"));
|
||||||
|
assert!(is_valid_model("gpt-4o"));
|
||||||
|
assert!(is_valid_model("gpt-3.5-turbo"));
|
||||||
|
assert!(!is_valid_model("invalid-model"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_delta_reasoning_parsing() {
|
||||||
|
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||||
|
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(delta.content.is_none());
|
||||||
|
assert_eq!(
|
||||||
|
delta.reasoning_content,
|
||||||
|
Some("Let me think...".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_stream_delta_content_parsing() {
|
||||||
|
let json = r#"{"content":"Hello","reasoning_content":null}"#;
|
||||||
|
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(delta.content, Some("Hello".to_string()));
|
||||||
|
assert!(delta.reasoning_content.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ pub struct OpenRouterClient {
|
|||||||
api_key: String,
|
api_key: String,
|
||||||
model: String,
|
model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
top_p: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@@ -55,24 +58,30 @@ impl OpenRouterClient {
|
|||||||
/// Create new OpenRouter client
|
/// Create new OpenRouter client
|
||||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(60))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url: "https://openrouter.ai/api/v1".to_string(),
|
base_url: "https://openrouter.ai/api/v1".to_string(),
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create with custom base URL
|
/// Create with custom base URL
|
||||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||||
let client = create_http_client(Duration::from_secs(60))?;
|
let client = create_http_client(Duration::from_secs(60))?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.to_string(),
|
api_key: api_key.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
client,
|
client,
|
||||||
|
max_tokens: 500,
|
||||||
|
temperature: 0.7,
|
||||||
|
top_p: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,6 +91,21 @@ impl OpenRouterClient {
|
|||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||||
|
self.max_tokens = max_tokens;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.temperature = temperature;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// List available models
|
/// List available models
|
||||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||||
let url = format!("{}/models", self.base_url);
|
let url = format!("{}/models", self.base_url);
|
||||||
@@ -182,8 +206,8 @@ impl OpenRouterClient {
|
|||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages,
|
messages,
|
||||||
max_tokens: Some(500),
|
max_tokens: Some(self.max_tokens),
|
||||||
temperature: Some(0.7),
|
temperature: Some(self.temperature),
|
||||||
stream: false,
|
stream: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
152
src/llm/thinking.rs
Normal file
152
src/llm/thinking.rs
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏
|
||||||
|
pub struct ThinkingStateManager {
|
||||||
|
is_thinking: AtomicBool,
|
||||||
|
on_start: Option<Box<dyn Fn() + Send + Sync>>,
|
||||||
|
on_end: Option<Box<dyn Fn() + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThinkingStateManager {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
is_thinking: AtomicBool::new(false),
|
||||||
|
on_start: None,
|
||||||
|
on_end: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置思考开始回调
|
||||||
|
pub fn on_thinking_start<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
|
||||||
|
self.on_start = Some(Box::new(callback));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置思考结束回调
|
||||||
|
pub fn on_thinking_end<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
|
||||||
|
self.on_end = Some(Box::new(callback));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 开始思考状态
|
||||||
|
pub fn start_thinking(&self) {
|
||||||
|
if !self.is_thinking.load(Ordering::SeqCst) {
|
||||||
|
self.is_thinking.store(true, Ordering::SeqCst);
|
||||||
|
if let Some(ref cb) = self.on_start {
|
||||||
|
cb();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 结束思考状态
|
||||||
|
pub fn end_thinking(&self) {
|
||||||
|
if self.is_thinking.load(Ordering::SeqCst) {
|
||||||
|
self.is_thinking.store(false, Ordering::SeqCst);
|
||||||
|
if let Some(ref cb) = self.on_end {
|
||||||
|
cb();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 当前是否处于思考状态
|
||||||
|
pub fn is_thinking(&self) -> bool {
|
||||||
|
self.is_thinking.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ThinkingStateManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 线程安全的思考状态管理器引用
|
||||||
|
pub type SharedThinkingState = Arc<ThinkingStateManager>;
|
||||||
|
|
||||||
|
/// 创建带有默认控制台输出的思考状态管理器
|
||||||
|
/// 在思考开始时打印 "thinking...",在思考结束时清除该标识
|
||||||
|
pub fn create_console_thinking_state() -> SharedThinkingState {
|
||||||
|
Arc::new(
|
||||||
|
ThinkingStateManager::new()
|
||||||
|
.on_thinking_start(|| {
|
||||||
|
eprint!("\rthinking...");
|
||||||
|
})
|
||||||
|
.on_thinking_end(|| {
|
||||||
|
eprint!("\r \r");
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_state_transitions() {
|
||||||
|
let manager = ThinkingStateManager::new();
|
||||||
|
assert!(!manager.is_thinking());
|
||||||
|
|
||||||
|
manager.start_thinking();
|
||||||
|
assert!(manager.is_thinking());
|
||||||
|
|
||||||
|
manager.end_thinking();
|
||||||
|
assert!(!manager.is_thinking());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_idempotent_start() {
|
||||||
|
let manager = ThinkingStateManager::new();
|
||||||
|
manager.start_thinking();
|
||||||
|
manager.start_thinking(); // 重复调用不应触发回调两次
|
||||||
|
assert!(manager.is_thinking());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_idempotent_end() {
|
||||||
|
let manager = ThinkingStateManager::new();
|
||||||
|
manager.end_thinking(); // 未开始时结束不应触发问题
|
||||||
|
assert!(!manager.is_thinking());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thinking_callbacks() {
|
||||||
|
let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
let events_clone = events.clone();
|
||||||
|
|
||||||
|
let manager = ThinkingStateManager::new()
|
||||||
|
.on_thinking_start(move || {
|
||||||
|
events_clone.lock().unwrap().push("start".to_string());
|
||||||
|
});
|
||||||
|
|
||||||
|
let events_clone2 = events.clone();
|
||||||
|
let manager = manager.on_thinking_end(move || {
|
||||||
|
events_clone2.lock().unwrap().push("end".to_string());
|
||||||
|
});
|
||||||
|
|
||||||
|
manager.start_thinking();
|
||||||
|
manager.end_thinking();
|
||||||
|
|
||||||
|
let recorded = events.lock().unwrap();
|
||||||
|
assert_eq!(recorded.len(), 2);
|
||||||
|
assert_eq!(recorded[0], "start");
|
||||||
|
assert_eq!(recorded[1], "end");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_console_thinking_state() {
|
||||||
|
let state = create_console_thinking_state();
|
||||||
|
assert!(!state.is_thinking());
|
||||||
|
state.start_thinking();
|
||||||
|
assert!(state.is_thinking());
|
||||||
|
state.end_thinking();
|
||||||
|
assert!(!state.is_thinking());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default() {
|
||||||
|
let manager = ThinkingStateManager::default();
|
||||||
|
assert!(!manager.is_thinking());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -85,11 +85,10 @@ impl KeyringManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
|
pub fn get_api_key(&self, provider: &str) -> Result<Option<String>> {
|
||||||
if let Ok(key) = env::var(ENV_API_KEY) {
|
if let Ok(key) = env::var(ENV_API_KEY)
|
||||||
if !key.is_empty() {
|
&& !key.is_empty() {
|
||||||
return Ok(Some(key));
|
return Ok(Some(key));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !self.is_available() {
|
if !self.is_available() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
@@ -251,8 +250,8 @@ pub fn get_default_model(provider: &str) -> &'static str {
|
|||||||
match provider {
|
match provider {
|
||||||
"openai" => "gpt-4",
|
"openai" => "gpt-4",
|
||||||
"anthropic" => "claude-3-sonnet-20240229",
|
"anthropic" => "claude-3-sonnet-20240229",
|
||||||
"kimi" => "moonshot-v1-8k",
|
"kimi" => "kimi-k2.6",
|
||||||
"deepseek" => "deepseek-chat",
|
"deepseek" => "deepseek-v4-flash",
|
||||||
"openrouter" => "openai/gpt-3.5-turbo",
|
"openrouter" => "openai/gpt-3.5-turbo",
|
||||||
"ollama" => "llama2",
|
"ollama" => "llama2",
|
||||||
_ => "",
|
_ => "",
|
||||||
|
|||||||
@@ -1,38 +1,11 @@
|
|||||||
use assert_cmd::Command;
|
use assert_cmd::cargo::cargo_bin_cmd;
|
||||||
use predicates::prelude::*;
|
use predicates::prelude::*;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
fn create_git_repo(dir: &PathBuf) -> std::process::Output {
|
|
||||||
std::process::Command::new("git")
|
|
||||||
.args(&["init"])
|
|
||||||
.current_dir(dir)
|
|
||||||
.output()
|
|
||||||
.expect("Failed to init git repo")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn configure_git_user(dir: &PathBuf) {
|
|
||||||
std::process::Command::new("git")
|
|
||||||
.args(&["config", "user.name", "Test User"])
|
|
||||||
.current_dir(dir)
|
|
||||||
.output()
|
|
||||||
.expect("Failed to configure git user name");
|
|
||||||
|
|
||||||
std::process::Command::new("git")
|
|
||||||
.args(&["config", "user.email", "test@example.com"])
|
|
||||||
.current_dir(dir)
|
|
||||||
.output()
|
|
||||||
.expect("Failed to configure git user email");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn setup_git_repo(dir: &PathBuf) {
|
|
||||||
create_git_repo(dir);
|
|
||||||
configure_git_user(dir);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn init_quicommit(config_path: &PathBuf) {
|
fn init_quicommit(config_path: &PathBuf) {
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
}
|
}
|
||||||
@@ -46,7 +19,7 @@ mod config_export {
|
|||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
init_quicommit(&config_path);
|
init_quicommit(&config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "export", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["config", "export", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -62,7 +35,7 @@ mod config_export {
|
|||||||
let export_path = temp_dir.path().join("exported.toml");
|
let export_path = temp_dir.path().join("exported.toml");
|
||||||
init_quicommit(&config_path);
|
init_quicommit(&config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path.to_str().unwrap(),
|
"--config", config_path.to_str().unwrap(),
|
||||||
@@ -88,7 +61,7 @@ mod config_export {
|
|||||||
let export_path = temp_dir.path().join("encrypted.toml");
|
let export_path = temp_dir.path().join("encrypted.toml");
|
||||||
init_quicommit(&config_path);
|
init_quicommit(&config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path.to_str().unwrap(),
|
"--config", config_path.to_str().unwrap(),
|
||||||
@@ -164,7 +137,7 @@ keep_changelog_types_english = true
|
|||||||
"#;
|
"#;
|
||||||
fs::write(&import_path, plain_config).unwrap();
|
fs::write(&import_path, plain_config).unwrap();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "import",
|
"config", "import",
|
||||||
"--config", config_path.to_str().unwrap(),
|
"--config", config_path.to_str().unwrap(),
|
||||||
@@ -175,7 +148,7 @@ keep_changelog_types_english = true
|
|||||||
.success()
|
.success()
|
||||||
.stdout(predicate::str::contains("Configuration imported"));
|
.stdout(predicate::str::contains("Configuration imported"));
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "get", "llm.provider", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["config", "get", "llm.provider", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -191,14 +164,14 @@ keep_changelog_types_english = true
|
|||||||
|
|
||||||
init_quicommit(&config_path1);
|
init_quicommit(&config_path1);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "set", "llm.provider", "anthropic",
|
"config", "set", "llm.provider", "anthropic",
|
||||||
"--config", config_path1.to_str().unwrap()
|
"--config", config_path1.to_str().unwrap()
|
||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path1.to_str().unwrap(),
|
"--config", config_path1.to_str().unwrap(),
|
||||||
@@ -207,7 +180,7 @@ keep_changelog_types_english = true
|
|||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "import",
|
"config", "import",
|
||||||
"--config", config_path2.to_str().unwrap(),
|
"--config", config_path2.to_str().unwrap(),
|
||||||
@@ -218,7 +191,7 @@ keep_changelog_types_english = true
|
|||||||
.success()
|
.success()
|
||||||
.stdout(predicate::str::contains("Configuration imported"));
|
.stdout(predicate::str::contains("Configuration imported"));
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]);
|
cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -233,7 +206,7 @@ keep_changelog_types_english = true
|
|||||||
|
|
||||||
init_quicommit(&config_path);
|
init_quicommit(&config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path.to_str().unwrap(),
|
"--config", config_path.to_str().unwrap(),
|
||||||
@@ -242,7 +215,7 @@ keep_changelog_types_english = true
|
|||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "import",
|
"config", "import",
|
||||||
"--config", config_path.to_str().unwrap(),
|
"--config", config_path.to_str().unwrap(),
|
||||||
@@ -267,14 +240,14 @@ mod config_export_import_roundtrip {
|
|||||||
|
|
||||||
init_quicommit(&config_path1);
|
init_quicommit(&config_path1);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "set", "llm.model", "gpt-4-turbo",
|
"config", "set", "llm.model", "gpt-4-turbo",
|
||||||
"--config", config_path1.to_str().unwrap()
|
"--config", config_path1.to_str().unwrap()
|
||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path1.to_str().unwrap(),
|
"--config", config_path1.to_str().unwrap(),
|
||||||
@@ -283,7 +256,7 @@ mod config_export_import_roundtrip {
|
|||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "import",
|
"config", "import",
|
||||||
"--config", config_path2.to_str().unwrap(),
|
"--config", config_path2.to_str().unwrap(),
|
||||||
@@ -291,7 +264,7 @@ mod config_export_import_roundtrip {
|
|||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]);
|
cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -308,21 +281,21 @@ mod config_export_import_roundtrip {
|
|||||||
|
|
||||||
init_quicommit(&config_path1);
|
init_quicommit(&config_path1);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "set", "llm.provider", "deepseek",
|
"config", "set", "llm.provider", "deepseek",
|
||||||
"--config", config_path1.to_str().unwrap()
|
"--config", config_path1.to_str().unwrap()
|
||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "set", "llm.model", "deepseek-chat",
|
"config", "set", "llm.model", "deepseek-chat",
|
||||||
"--config", config_path1.to_str().unwrap()
|
"--config", config_path1.to_str().unwrap()
|
||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "export",
|
"config", "export",
|
||||||
"--config", config_path1.to_str().unwrap(),
|
"--config", config_path1.to_str().unwrap(),
|
||||||
@@ -335,7 +308,7 @@ mod config_export_import_roundtrip {
|
|||||||
assert!(exported_content.starts_with("ENCRYPTED:"));
|
assert!(exported_content.starts_with("ENCRYPTED:"));
|
||||||
assert!(!exported_content.contains("deepseek"));
|
assert!(!exported_content.contains("deepseek"));
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&[
|
cmd.args(&[
|
||||||
"config", "import",
|
"config", "import",
|
||||||
"--config", config_path2.to_str().unwrap(),
|
"--config", config_path2.to_str().unwrap(),
|
||||||
@@ -344,13 +317,13 @@ mod config_export_import_roundtrip {
|
|||||||
]);
|
]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]);
|
cmd.args(&["config", "get", "llm.provider", "--config", config_path2.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
.stdout(predicate::str::contains("deepseek"));
|
.stdout(predicate::str::contains("deepseek"));
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]);
|
cmd.args(&["config", "get", "llm.model", "--config", config_path2.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use assert_cmd::Command;
|
use assert_cmd::cargo::cargo_bin_cmd;
|
||||||
use predicates::prelude::*;
|
use predicates::prelude::*;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
@@ -59,7 +59,7 @@ fn setup_test_repo_with_file(dir: &PathBuf, file_name: &str, file_content: &str)
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn init_quicommit(dir: &PathBuf, config_path: &PathBuf) {
|
fn init_quicommit(dir: &PathBuf, config_path: &PathBuf) {
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(dir);
|
.current_dir(dir);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
@@ -70,7 +70,7 @@ mod cli_basic {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_help() {
|
fn test_help() {
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.arg("--help");
|
cmd.arg("--help");
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -83,7 +83,7 @@ mod cli_basic {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_version() {
|
fn test_version() {
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.arg("--version");
|
cmd.arg("--version");
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -92,7 +92,7 @@ mod cli_basic {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_no_args_shows_help() {
|
fn test_no_args_shows_help() {
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.failure()
|
.failure()
|
||||||
.stderr(predicate::str::contains("Usage:"));
|
.stderr(predicate::str::contains("Usage:"));
|
||||||
@@ -106,7 +106,7 @@ mod cli_basic {
|
|||||||
create_git_repo(&repo_path);
|
create_git_repo(&repo_path);
|
||||||
configure_git_user(&repo_path);
|
configure_git_user(&repo_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["-vv", "init", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["-vv", "init", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -122,7 +122,7 @@ mod init_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -135,7 +135,7 @@ mod init_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
@@ -152,7 +152,7 @@ mod init_command {
|
|||||||
|
|
||||||
let config_path = repo_path.join("test_config.toml");
|
let config_path = repo_path.join("test_config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -164,11 +164,11 @@ mod init_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--reset", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--reset", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
.success()
|
.success()
|
||||||
@@ -184,7 +184,7 @@ mod profile_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -197,11 +197,11 @@ mod profile_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["profile", "list", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -218,11 +218,11 @@ mod config_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["config", "show", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -235,11 +235,11 @@ mod config_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["config", "path", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -256,11 +256,11 @@ mod commit_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(temp_dir.path());
|
.current_dir(temp_dir.path());
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ mod commit_command {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--manual", "-m", "test: empty", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--manual", "-m", "test: empty", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -296,7 +296,7 @@ mod commit_command {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--manual", "-m", "test: add test file", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--manual", "-m", "test: add test file", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -314,7 +314,7 @@ mod commit_command {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--date", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--date", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -322,6 +322,22 @@ mod commit_command {
|
|||||||
.success()
|
.success()
|
||||||
.stdout(predicate::str::contains("Dry run"));
|
.stdout(predicate::str::contains("Dry run"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_commit_with_think_flag() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let repo_path = temp_dir.path().to_path_buf();
|
||||||
|
setup_test_repo_with_file(&repo_path, "test.txt", "Hello, World!");
|
||||||
|
|
||||||
|
let config_path = repo_path.join("config.toml");
|
||||||
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
|
cmd.args(&["commit", "--think", "--manual", "-m", "test: think flag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
|
cmd.assert().success();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod tag_command {
|
mod tag_command {
|
||||||
@@ -332,11 +348,11 @@ mod tag_command {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["tag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["tag", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(temp_dir.path());
|
.current_dir(temp_dir.path());
|
||||||
|
|
||||||
@@ -358,7 +374,7 @@ mod tag_command {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["tag", "--name", "v0.1.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["tag", "--name", "v0.1.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -366,6 +382,26 @@ mod tag_command {
|
|||||||
.success()
|
.success()
|
||||||
.stdout(predicate::str::contains("v0.1.0"));
|
.stdout(predicate::str::contains("v0.1.0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tag_with_think_flag() {
|
||||||
|
let temp_dir = TempDir::new().unwrap();
|
||||||
|
let repo_path = temp_dir.path().to_path_buf();
|
||||||
|
setup_git_repo(&repo_path);
|
||||||
|
|
||||||
|
create_test_file(&repo_path, "test.txt", "content");
|
||||||
|
stage_file(&repo_path, "test.txt");
|
||||||
|
create_commit(&repo_path, "feat: initial commit");
|
||||||
|
|
||||||
|
let config_path = repo_path.join("config.toml");
|
||||||
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
|
cmd.args(&["tag", "--think", "--name", "v0.2.0", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
|
cmd.assert().success();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod changelog_command {
|
mod changelog_command {
|
||||||
@@ -382,7 +418,7 @@ mod changelog_command {
|
|||||||
|
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["changelog", "--init", "--output", changelog_path.to_str().unwrap(), "--config", config_path.to_str().unwrap()])
|
cmd.args(&["changelog", "--init", "--output", changelog_path.to_str().unwrap(), "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -404,7 +440,7 @@ mod changelog_command {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["changelog", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["changelog", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -421,7 +457,7 @@ mod cross_platform {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("subdir").join("config.toml");
|
let config_path = temp_dir.path().join("subdir").join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
@@ -435,7 +471,7 @@ mod cross_platform {
|
|||||||
fs::create_dir_all(&space_dir).unwrap();
|
fs::create_dir_all(&space_dir).unwrap();
|
||||||
let config_path = space_dir.join("config.toml");
|
let config_path = space_dir.join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
@@ -449,7 +485,7 @@ mod cross_platform {
|
|||||||
fs::create_dir_all(&unicode_dir).unwrap();
|
fs::create_dir_all(&unicode_dir).unwrap();
|
||||||
let config_path = unicode_dir.join("config.toml");
|
let config_path = unicode_dir.join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
@@ -523,7 +559,7 @@ mod validators {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--manual", "-m", "invalid commit message without type", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--manual", "-m", "invalid commit message without type", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -541,7 +577,7 @@ mod validators {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--manual", "-m", "feat: add new feature", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--manual", "-m", "feat: add new feature", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -563,7 +599,7 @@ mod subcommands {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["c", "--manual", "-m", "fix: test", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["c", "--manual", "-m", "fix: test", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -577,7 +613,7 @@ mod subcommands {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["i", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -590,11 +626,11 @@ mod subcommands {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let config_path = temp_dir.path().join("config.toml");
|
let config_path = temp_dir.path().join("config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["p", "list", "--config", config_path.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -611,7 +647,7 @@ mod edge_cases {
|
|||||||
let temp_dir = TempDir::new().unwrap();
|
let temp_dir = TempDir::new().unwrap();
|
||||||
let non_existent_config = temp_dir.path().join("non_existent_config.toml");
|
let non_existent_config = temp_dir.path().join("non_existent_config.toml");
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["config", "show", "--config", non_existent_config.to_str().unwrap()]);
|
cmd.args(&["config", "show", "--config", non_existent_config.to_str().unwrap()]);
|
||||||
|
|
||||||
cmd.assert()
|
cmd.assert()
|
||||||
@@ -627,11 +663,11 @@ mod edge_cases {
|
|||||||
let repo_path = temp_dir.path().to_path_buf();
|
let repo_path = temp_dir.path().to_path_buf();
|
||||||
|
|
||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
cmd.args(&["init", "--yes", "--config", config_path.to_str().unwrap()]);
|
||||||
cmd.assert().success();
|
cmd.assert().success();
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
@@ -649,7 +685,7 @@ mod edge_cases {
|
|||||||
let config_path = repo_path.join("config.toml");
|
let config_path = repo_path.join("config.toml");
|
||||||
init_quicommit(&repo_path, &config_path);
|
init_quicommit(&repo_path, &config_path);
|
||||||
|
|
||||||
let mut cmd = Command::cargo_bin("quicommit").unwrap();
|
let mut cmd = cargo_bin_cmd!("quicommit");
|
||||||
cmd.args(&["commit", "--manual", "-m", "", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
cmd.args(&["commit", "--manual", "-m", "", "--dry-run", "--yes", "--config", config_path.to_str().unwrap()])
|
||||||
.current_dir(&repo_path);
|
.current_dir(&repo_path);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user