style: 格式化代码并优化导入顺序
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
@@ -522,13 +522,13 @@ impl AnthropicClient {
|
||||
}
|
||||
}
|
||||
"content_block_stop" => {
|
||||
if in_thinking {
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
in_thinking = false;
|
||||
}
|
||||
}
|
||||
if in_thinking {
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
in_thinking = false;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -618,10 +618,7 @@ mod tests {
|
||||
let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#;
|
||||
let event: SseEvent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(event.event_type, "content_block_start");
|
||||
assert_eq!(
|
||||
event.content_block.unwrap().content_type,
|
||||
"thinking"
|
||||
);
|
||||
assert_eq!(event.content_block.unwrap().content_type, "thinking");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
@@ -459,34 +459,39 @@ impl DeepSeekClient {
|
||||
for choice in &chunk.choices {
|
||||
// 处理 reasoning_content
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||
&& !reasoning.is_empty() {
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
&& !reasoning.is_empty()
|
||||
{
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
// reasoning_content 不对外输出,仅用于内部状态判断
|
||||
continue;
|
||||
}
|
||||
// reasoning_content 不对外输出,仅用于内部状态判断
|
||||
continue;
|
||||
}
|
||||
|
||||
// 处理 content
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty() {
|
||||
// reasoning 结束,content 开始出现时移除 thinking 标识
|
||||
if has_reasoning && !has_content
|
||||
&& let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
&& !content.is_empty()
|
||||
{
|
||||
// reasoning 结束,content 开始出现时移除 thinking 标识
|
||||
if has_reasoning
|
||||
&& !has_content
|
||||
&& let Some(state) = thinking_state
|
||||
{
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
}
|
||||
|
||||
// 检查 finish_reason
|
||||
if let Some(ref reason) = choice.finish_reason
|
||||
&& reason == "stop" {
|
||||
stream_ended = true;
|
||||
}
|
||||
&& reason == "stop"
|
||||
{
|
||||
stream_ended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
@@ -612,9 +617,6 @@ mod tests {
|
||||
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert!(delta.content.is_none());
|
||||
assert_eq!(
|
||||
delta.reasoning_content,
|
||||
Some("Let me think...".to_string())
|
||||
);
|
||||
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
@@ -431,30 +431,35 @@ impl KimiClient {
|
||||
Ok(chunk) => {
|
||||
for choice in &chunk.choices {
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||
&& !reasoning.is_empty() {
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
&& !reasoning.is_empty()
|
||||
{
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty() {
|
||||
if has_reasoning && !has_content
|
||||
&& let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
&& !content.is_empty()
|
||||
{
|
||||
if has_reasoning
|
||||
&& !has_content
|
||||
&& let Some(state) = thinking_state
|
||||
{
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
}
|
||||
|
||||
if let Some(ref reason) = choice.finish_reason
|
||||
&& reason == "stop" {
|
||||
stream_ended = true;
|
||||
}
|
||||
&& reason == "stop"
|
||||
{
|
||||
stream_ended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
||||
243
src/llm/mod.rs
243
src/llm/mod.rs
@@ -1,21 +1,21 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use crate::config::Language;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use std::time::Duration;
|
||||
use crate::config::Language;
|
||||
|
||||
pub mod anthropic;
|
||||
pub mod deepseek;
|
||||
pub mod kimi;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod anthropic;
|
||||
pub mod kimi;
|
||||
pub mod deepseek;
|
||||
pub mod openrouter;
|
||||
pub mod thinking;
|
||||
|
||||
pub use anthropic::AnthropicClient;
|
||||
pub use deepseek::DeepSeekClient;
|
||||
pub use kimi::KimiClient;
|
||||
pub use ollama::OllamaClient;
|
||||
pub use openai::OpenAiClient;
|
||||
pub use anthropic::AnthropicClient;
|
||||
pub use kimi::KimiClient;
|
||||
pub use deepseek::DeepSeekClient;
|
||||
pub use openrouter::OpenRouterClient;
|
||||
|
||||
/// LLM provider trait
|
||||
@@ -23,13 +23,13 @@ pub use openrouter::OpenRouterClient;
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Generate text from prompt
|
||||
async fn generate(&self, prompt: &str) -> Result<String>;
|
||||
|
||||
|
||||
/// Generate with system prompt
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>;
|
||||
|
||||
|
||||
/// Check if provider is available
|
||||
async fn is_available(&self) -> bool;
|
||||
|
||||
|
||||
/// Get provider name
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
@@ -84,13 +84,14 @@ impl LlmClient {
|
||||
let api_key = manager.get_api_key();
|
||||
|
||||
let provider: Box<dyn LlmProvider> = match provider {
|
||||
"ollama" => {
|
||||
Box::new(OllamaClient::new(&base_url, model)
|
||||
"ollama" => Box::new(
|
||||
OllamaClient::new(&base_url, model)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature))
|
||||
}
|
||||
.with_temperature(client_config.temperature),
|
||||
),
|
||||
"openai" => {
|
||||
let key = api_key.as_ref()
|
||||
let key = api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
@@ -108,7 +109,8 @@ impl LlmClient {
|
||||
Box::new(client)
|
||||
}
|
||||
"anthropic" => {
|
||||
let key = api_key.as_ref()
|
||||
let key = api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
@@ -128,7 +130,8 @@ impl LlmClient {
|
||||
Box::new(client)
|
||||
}
|
||||
"kimi" => {
|
||||
let key = api_key.as_ref()
|
||||
let key = api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
@@ -146,7 +149,8 @@ impl LlmClient {
|
||||
Box::new(client)
|
||||
}
|
||||
"deepseek" => {
|
||||
let key = api_key.as_ref()
|
||||
let key = api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
@@ -164,12 +168,15 @@ impl LlmClient {
|
||||
Box::new(client)
|
||||
}
|
||||
"openrouter" => {
|
||||
let key = api_key.as_ref()
|
||||
let key = api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
|
||||
Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?)
|
||||
Box::new(
|
||||
OpenRouterClient::with_base_url(key, model, &base_url)?
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?,
|
||||
)
|
||||
}
|
||||
_ => bail!("Unknown LLM provider: {}", provider),
|
||||
};
|
||||
@@ -196,7 +203,7 @@ impl LlmClient {
|
||||
language: Language,
|
||||
) -> Result<GeneratedCommit> {
|
||||
let system_prompt = get_commit_system_prompt(format, language);
|
||||
|
||||
|
||||
// Add language instruction to the prompt
|
||||
let language_instruction = match language {
|
||||
Language::Chinese => "\n\n请用中文生成提交消息。",
|
||||
@@ -207,10 +214,13 @@ impl LlmClient {
|
||||
Language::German => "\n\nBitte generieren Sie die Commit-Nachricht auf Deutsch.",
|
||||
Language::English => "",
|
||||
};
|
||||
|
||||
|
||||
let prompt = format!("{}{}", diff, language_instruction);
|
||||
let response = self.provider.generate_with_system(system_prompt, &prompt).await?;
|
||||
|
||||
let response = self
|
||||
.provider
|
||||
.generate_with_system(system_prompt, &prompt)
|
||||
.await?;
|
||||
|
||||
self.parse_commit_response(&response, format)
|
||||
}
|
||||
|
||||
@@ -223,7 +233,7 @@ impl LlmClient {
|
||||
) -> Result<String> {
|
||||
let system_prompt = get_tag_system_prompt(language);
|
||||
let commits_text = commits.join("\n");
|
||||
|
||||
|
||||
// Add language instruction to the prompt
|
||||
let language_instruction = match language {
|
||||
Language::Chinese => "\n\n请用中文生成标签消息。",
|
||||
@@ -234,10 +244,15 @@ impl LlmClient {
|
||||
Language::German => "\n\nBitte generieren Sie die Tag-Nachricht auf Deutsch.",
|
||||
Language::English => "",
|
||||
};
|
||||
|
||||
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction);
|
||||
|
||||
self.provider.generate_with_system(system_prompt, &prompt).await
|
||||
|
||||
let prompt = format!(
|
||||
"Version: {}\n\nCommits:\n{}{}",
|
||||
version, commits_text, language_instruction
|
||||
);
|
||||
|
||||
self.provider
|
||||
.generate_with_system(system_prompt, &prompt)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Generate changelog entry
|
||||
@@ -248,13 +263,13 @@ impl LlmClient {
|
||||
language: Language,
|
||||
) -> Result<String> {
|
||||
let system_prompt = get_changelog_system_prompt(language);
|
||||
|
||||
|
||||
let commits_text = commits
|
||||
.iter()
|
||||
.map(|(t, m)| format!("- [{}] {}", t, m))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
|
||||
// Add language instruction to the prompt
|
||||
let language_instruction = match language {
|
||||
Language::Chinese => "\n\n请用中文生成变更日志。",
|
||||
@@ -265,10 +280,15 @@ impl LlmClient {
|
||||
Language::German => "\n\nBitte generieren Sie das Changelog auf Deutsch.",
|
||||
Language::English => "",
|
||||
};
|
||||
|
||||
let prompt = format!("Version: {}\n\nCommits:\n{}{}", version, commits_text, language_instruction);
|
||||
|
||||
self.provider.generate_with_system(system_prompt, &prompt).await
|
||||
|
||||
let prompt = format!(
|
||||
"Version: {}\n\nCommits:\n{}{}",
|
||||
version, commits_text, language_instruction
|
||||
);
|
||||
|
||||
self.provider
|
||||
.generate_with_system(system_prompt, &prompt)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Check if provider is available
|
||||
@@ -277,8 +297,16 @@ impl LlmClient {
|
||||
}
|
||||
|
||||
/// Parse commit response from LLM
|
||||
fn parse_commit_response(&self, response: &str, format: crate::config::CommitFormat) -> Result<GeneratedCommit> {
|
||||
let lines: Vec<&str> = response.lines()
|
||||
fn parse_commit_response(
|
||||
&self,
|
||||
response: &str,
|
||||
format: crate::config::CommitFormat,
|
||||
) -> Result<GeneratedCommit> {
|
||||
// Clean markdown code fences from the response
|
||||
let cleaned = Self::strip_code_fences(response);
|
||||
|
||||
let lines: Vec<&str> = cleaned
|
||||
.lines()
|
||||
.map(|l| l.trim())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect();
|
||||
@@ -295,28 +323,89 @@ impl LlmClient {
|
||||
);
|
||||
}
|
||||
|
||||
let first_line = lines[0];
|
||||
|
||||
// Find the line most likely to be the commit subject
|
||||
let first_line = Self::find_commit_subject_line(&lines, format);
|
||||
|
||||
// Parse based on format
|
||||
match format {
|
||||
crate::config::CommitFormat::Conventional => {
|
||||
self.parse_conventional_commit(first_line, lines)
|
||||
self.parse_conventional_commit(first_line, &lines, response)
|
||||
}
|
||||
crate::config::CommitFormat::Commitlint => {
|
||||
self.parse_commitlint_commit(first_line, lines)
|
||||
self.parse_commitlint_commit(first_line, &lines, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove surrounding markdown code fences (```) from LLM output
|
||||
fn strip_code_fences(response: &str) -> String {
|
||||
let mut lines: Vec<&str> = response.lines().collect();
|
||||
|
||||
// Strip leading fence lines (``` or ```lang)
|
||||
while lines.first().map_or(false, |l| l.trim().starts_with("```")) {
|
||||
lines.remove(0);
|
||||
}
|
||||
|
||||
// Strip trailing fence lines
|
||||
while lines.last().map_or(false, |l| l.trim() == "```") {
|
||||
lines.pop();
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
/// Find the line that is most likely the commit subject among extracted lines
|
||||
fn find_commit_subject_line<'a>(
|
||||
lines: &[&'a str],
|
||||
format: crate::config::CommitFormat,
|
||||
) -> &'a str {
|
||||
let valid_types = crate::utils::validators::get_commit_types(matches!(
|
||||
format,
|
||||
crate::config::CommitFormat::Commitlint
|
||||
));
|
||||
|
||||
// First pass: line starting with a known type that also has proper syntax
|
||||
// (e.g. "type:", "type(scope):", "type!:")
|
||||
for &line in lines {
|
||||
let trimmed = line.trim();
|
||||
for &t in valid_types {
|
||||
if let Some(rest) = trimmed.strip_prefix(t) {
|
||||
if rest.starts_with(':') || rest.starts_with('(') || rest.starts_with("!:") {
|
||||
return trimmed;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: any line containing a colon (generic "prefix: description")
|
||||
for &line in lines {
|
||||
if line.contains(':') {
|
||||
return line.trim();
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: return the first line as-is
|
||||
lines[0].trim()
|
||||
}
|
||||
|
||||
fn parse_conventional_commit(
|
||||
&self,
|
||||
first_line: &str,
|
||||
lines: Vec<&str>,
|
||||
lines: &[&str],
|
||||
raw_response: &str,
|
||||
) -> Result<GeneratedCommit> {
|
||||
// Parse: type(scope)!: description
|
||||
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
bail!("Invalid conventional commit format: missing colon");
|
||||
let preview: String = raw_response.chars().take(300).collect();
|
||||
bail!(
|
||||
"Invalid conventional commit format: missing colon.\n\
|
||||
Parsed subject line: '{}'\n\
|
||||
Raw response preview: '{}'\n\
|
||||
Expected: <type>[optional scope]: <description>",
|
||||
first_line,
|
||||
preview
|
||||
);
|
||||
}
|
||||
|
||||
let type_part = parts[0];
|
||||
@@ -339,7 +428,7 @@ impl LlmClient {
|
||||
};
|
||||
|
||||
// Extract body and footer
|
||||
let (body, footer) = self.extract_body_footer(&lines);
|
||||
let (body, footer) = self.extract_body_footer(lines);
|
||||
|
||||
Ok(GeneratedCommit {
|
||||
commit_type,
|
||||
@@ -354,12 +443,21 @@ impl LlmClient {
|
||||
fn parse_commitlint_commit(
|
||||
&self,
|
||||
first_line: &str,
|
||||
lines: Vec<&str>,
|
||||
lines: &[&str],
|
||||
raw_response: &str,
|
||||
) -> Result<GeneratedCommit> {
|
||||
// Similar parsing but with commitlint rules
|
||||
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
bail!("Invalid commit format: missing colon");
|
||||
let preview: String = raw_response.chars().take(300).collect();
|
||||
bail!(
|
||||
"Invalid commit format: missing colon.\n\
|
||||
Parsed subject line: '{}'\n\
|
||||
Raw response preview: '{}'\n\
|
||||
Expected: <type>[optional scope]: <subject>",
|
||||
first_line,
|
||||
preview
|
||||
);
|
||||
}
|
||||
|
||||
let type_part = parts[0];
|
||||
@@ -405,8 +503,14 @@ impl LlmClient {
|
||||
}
|
||||
|
||||
// Look for footer markers
|
||||
let footer_markers = ["BREAKING CHANGE:", "Closes", "Fixes", "Refs", "Co-authored-by:"];
|
||||
|
||||
let footer_markers = [
|
||||
"BREAKING CHANGE:",
|
||||
"Closes",
|
||||
"Fixes",
|
||||
"Refs",
|
||||
"Co-authored-by:",
|
||||
];
|
||||
|
||||
let mut body_lines = vec![];
|
||||
let mut footer_lines = vec![];
|
||||
let mut in_footer = false;
|
||||
@@ -415,7 +519,7 @@ impl LlmClient {
|
||||
if footer_markers.iter().any(|m| line.starts_with(m)) {
|
||||
in_footer = true;
|
||||
}
|
||||
|
||||
|
||||
if in_footer {
|
||||
footer_lines.push(*line);
|
||||
} else {
|
||||
@@ -485,17 +589,34 @@ pub(crate) fn create_http_client(timeout: Duration) -> Result<reqwest::Client> {
|
||||
}
|
||||
|
||||
/// Get commit system prompt based on format and language
|
||||
fn get_commit_system_prompt(format: crate::config::CommitFormat, language: Language) -> &'static str {
|
||||
fn get_commit_system_prompt(
|
||||
format: crate::config::CommitFormat,
|
||||
language: Language,
|
||||
) -> &'static str {
|
||||
match (format, language) {
|
||||
(crate::config::CommitFormat::Conventional, Language::Chinese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH,
|
||||
(crate::config::CommitFormat::Conventional, Language::Japanese) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA,
|
||||
(crate::config::CommitFormat::Conventional, Language::Korean) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO,
|
||||
(crate::config::CommitFormat::Conventional, Language::Spanish) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES,
|
||||
(crate::config::CommitFormat::Conventional, Language::French) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR,
|
||||
(crate::config::CommitFormat::Conventional, Language::German) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE,
|
||||
(crate::config::CommitFormat::Conventional, Language::Chinese) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ZH
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, Language::Japanese) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_JA
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, Language::Korean) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_KO
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, Language::Spanish) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_ES
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, Language::French) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_FR
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, Language::German) => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT_DE
|
||||
}
|
||||
(crate::config::CommitFormat::Conventional, _) => CONVENTIONAL_COMMIT_SYSTEM_PROMPT,
|
||||
(crate::config::CommitFormat::Commitlint, Language::Chinese) => COMMITLINT_SYSTEM_PROMPT_ZH,
|
||||
(crate::config::CommitFormat::Commitlint, Language::Japanese) => COMMITLINT_SYSTEM_PROMPT_JA,
|
||||
(crate::config::CommitFormat::Commitlint, Language::Japanese) => {
|
||||
COMMITLINT_SYSTEM_PROMPT_JA
|
||||
}
|
||||
(crate::config::CommitFormat::Commitlint, Language::Korean) => COMMITLINT_SYSTEM_PROMPT_KO,
|
||||
(crate::config::CommitFormat::Commitlint, Language::Spanish) => COMMITLINT_SYSTEM_PROMPT_ES,
|
||||
(crate::config::CommitFormat::Commitlint, Language::French) => COMMITLINT_SYSTEM_PROMPT_FR,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -50,8 +50,8 @@ struct ModelInfo {
|
||||
impl OllamaClient {
|
||||
/// Create new Ollama client
|
||||
pub fn new(base_url: &str, model: &str) -> Self {
|
||||
let client = create_http_client(Duration::from_secs(120))
|
||||
.expect("Failed to create HTTP client");
|
||||
let client =
|
||||
create_http_client(Duration::from_secs(120)).expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
@@ -65,8 +65,7 @@ impl OllamaClient {
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.client = create_http_client(timeout)
|
||||
.expect("Failed to create HTTP client");
|
||||
self.client = create_http_client(timeout).expect("Failed to create HTTP client");
|
||||
self
|
||||
}
|
||||
|
||||
@@ -88,49 +87,51 @@ impl OllamaClient {
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list Ollama models")?;
|
||||
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: ListModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
|
||||
Ok(result.models.into_iter().map(|m| m.name).collect())
|
||||
}
|
||||
|
||||
/// Pull a model
|
||||
pub async fn pull_model(&self, model: &str) -> Result<()> {
|
||||
let url = format!("{}/api/pull", self.base_url);
|
||||
|
||||
|
||||
let request = serde_json::json!({
|
||||
"name": model,
|
||||
"stream": false,
|
||||
});
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to pull Ollama model")?;
|
||||
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama pull error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -151,13 +152,13 @@ impl LlmProvider for OllamaClient {
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let url = format!("{}/api/generate", self.base_url);
|
||||
|
||||
|
||||
let system = if system.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system.to_string())
|
||||
};
|
||||
|
||||
|
||||
let request = GenerateRequest {
|
||||
model: self.model.clone(),
|
||||
prompt: user.to_string(),
|
||||
@@ -168,31 +169,32 @@ impl LlmProvider for OllamaClient {
|
||||
num_predict: Some(self.max_tokens),
|
||||
},
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Ollama")?;
|
||||
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: GenerateResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
|
||||
Ok(result.response.trim().to_string())
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
|
||||
match self.client.get(&url).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
@@ -411,7 +411,8 @@ impl OpenAiClient {
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty()
|
||||
{
|
||||
if has_reasoning && !has_content
|
||||
if has_reasoning
|
||||
&& !has_content
|
||||
&& let Some(state) = thinking_state
|
||||
{
|
||||
state.end_thinking();
|
||||
@@ -465,12 +466,7 @@ pub struct AzureOpenAiClient {
|
||||
}
|
||||
|
||||
impl AzureOpenAiClient {
|
||||
pub fn new(
|
||||
endpoint: &str,
|
||||
api_key: &str,
|
||||
deployment: &str,
|
||||
api_version: &str,
|
||||
) -> Result<Self> {
|
||||
pub fn new(endpoint: &str, api_key: &str, deployment: &str, api_version: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
@@ -642,10 +638,7 @@ mod tests {
|
||||
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert!(delta.content.is_none());
|
||||
assert_eq!(
|
||||
delta.reasoning_content,
|
||||
Some("Let me think...".to_string())
|
||||
);
|
||||
assert_eq!(delta.reasoning_content, Some("Let me think...".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,281 +1,286 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// OpenRouter API client
|
||||
pub struct OpenRouterClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiError {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
/// Create new OpenRouter client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://openrouter.ai/api/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list OpenRouter models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("OpenRouter API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenRouter response")?;
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenRouterClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"openrouter"
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to OpenRouter")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
bail!("OpenRouter API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenRouter response")?;
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Popular OpenRouter models
|
||||
pub const OPENROUTER_MODELS: &[&str] = &[
|
||||
"openai/gpt-3.5-turbo",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-4-turbo",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-3-haiku",
|
||||
"google/gemini-pro",
|
||||
"meta-llama/llama-2-70b-chat",
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"01-ai/yi-34b-chat",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(_model: &str) -> bool {
|
||||
// Since OpenRouter supports many models, we'll allow any model name
|
||||
// but provide some popular ones as suggestions
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
assert!(is_valid_model("openai/gpt-4"));
|
||||
assert!(is_valid_model("custom/model"));
|
||||
}
|
||||
}
|
||||
use super::{LlmProvider, create_http_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// OpenRouter API client
|
||||
pub struct OpenRouterClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiError {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
/// Create new OpenRouter client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://openrouter.ai/api/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list OpenRouter models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("OpenRouter API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenRouter response")?;
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenRouterClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"openrouter"
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to OpenRouter")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"OpenRouter API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("OpenRouter API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenRouter response")?;
|
||||
|
||||
result
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Popular OpenRouter models
|
||||
pub const OPENROUTER_MODELS: &[&str] = &[
|
||||
"openai/gpt-3.5-turbo",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-4-turbo",
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-3-haiku",
|
||||
"google/gemini-pro",
|
||||
"meta-llama/llama-2-70b-chat",
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"01-ai/yi-34b-chat",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(_model: &str) -> bool {
|
||||
// Since OpenRouter supports many models, we'll allow any model name
|
||||
// but provide some popular ones as suggestions
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
assert!(is_valid_model("openai/gpt-4"));
|
||||
assert!(is_valid_model("custom/model"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏
|
||||
pub struct ThinkingStateManager {
|
||||
@@ -115,10 +115,9 @@ mod tests {
|
||||
let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_clone = events.clone();
|
||||
|
||||
let manager = ThinkingStateManager::new()
|
||||
.on_thinking_start(move || {
|
||||
events_clone.lock().unwrap().push("start".to_string());
|
||||
});
|
||||
let manager = ThinkingStateManager::new().on_thinking_start(move || {
|
||||
events_clone.lock().unwrap().push("start".to_string());
|
||||
});
|
||||
|
||||
let events_clone2 = events.clone();
|
||||
let manager = manager.on_thinking_end(move || {
|
||||
|
||||
Reference in New Issue
Block a user