LLM支持优化
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Anthropic Claude API client
|
||||
@@ -9,6 +11,12 @@ pub struct AnthropicClient {
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
thinking_budget_tokens: u32,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -17,24 +25,58 @@ struct MessagesRequest {
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f32>,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
system: Option<Vec<SystemContent>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thinking: Option<ThinkingConfig>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
struct SystemContent {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ThinkingConfig {
|
||||
#[serde(rename = "type")]
|
||||
thinking_type: String,
|
||||
budget_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct AnthropicMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
content: AnthropicContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(untagged)]
|
||||
enum AnthropicContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct ContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessagesResponse {
|
||||
content: Vec<ContentBlock>,
|
||||
content: Vec<ResponseContentBlock>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ContentBlock {
|
||||
struct ResponseContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
text: String,
|
||||
@@ -52,31 +94,112 @@ struct AnthropicError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
// --- Streaming SSE event structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
#[serde(default)]
|
||||
message: Option<SseMessage>,
|
||||
#[serde(default)]
|
||||
index: Option<u32>,
|
||||
#[serde(default)]
|
||||
content_block: Option<SseContentBlock>,
|
||||
#[serde(default)]
|
||||
delta: Option<SseDelta>,
|
||||
#[serde(default)]
|
||||
usage: Option<SseUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<Vec<SseContentBlock>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SseContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SseDelta {
|
||||
#[serde(rename = "type")]
|
||||
delta_type: Option<String>,
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SseUsage {
|
||||
#[serde(default)]
|
||||
output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Create new Anthropic client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
thinking_budget_tokens: 1024,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thinking_budget_tokens(mut self, budget_tokens: u32) -> Self {
|
||||
self.thinking_budget_tokens = budget_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||
self.thinking_state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
// Anthropic doesn't have a models API endpoint, return predefined list
|
||||
Ok(ANTHROPIC_MODELS.iter().map(|&m| m.to_string()).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
let url = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
@@ -84,14 +207,18 @@ impl AnthropicClient {
|
||||
model: self.model.clone(),
|
||||
max_tokens: 5,
|
||||
temperature: Some(0.0),
|
||||
top_p: None,
|
||||
messages: vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hi".to_string(),
|
||||
content: AnthropicContent::Text("Hi".to_string()),
|
||||
}],
|
||||
system: None,
|
||||
thinking: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
@@ -124,25 +251,28 @@ impl LlmProvider for AnthropicClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
content: AnthropicContent::Text(prompt.to_string()),
|
||||
}];
|
||||
|
||||
self.messages_request(messages, None).await
|
||||
|
||||
self.messages_request_with_retry(messages, None).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let messages = vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
content: AnthropicContent::Text(user.to_string()),
|
||||
}];
|
||||
|
||||
|
||||
let system = if system.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system.to_string())
|
||||
Some(vec![SystemContent {
|
||||
content_type: "text".to_string(),
|
||||
text: system.to_string(),
|
||||
}])
|
||||
};
|
||||
|
||||
self.messages_request(messages, system).await
|
||||
|
||||
self.messages_request_with_retry(messages, system).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
@@ -155,22 +285,81 @@ impl LlmProvider for AnthropicClient {
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
async fn messages_request_with_retry(
|
||||
&self,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
system: Option<Vec<SystemContent>>,
|
||||
) -> Result<String> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=3 {
|
||||
match self
|
||||
.messages_request(messages.clone(), system.clone())
|
||||
.await
|
||||
{
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
let is_retryable = err_msg.contains("timeout")
|
||||
|| err_msg.contains("connection")
|
||||
|| err_msg.contains("temporary")
|
||||
|| err_msg.contains("5")
|
||||
&& (err_msg.contains("500")
|
||||
|| err_msg.contains("502")
|
||||
|| err_msg.contains("503")
|
||||
|| err_msg.contains("504"));
|
||||
|
||||
if !is_retryable || attempt == 3 {
|
||||
last_error = Some(e);
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||
}
|
||||
|
||||
async fn messages_request(
|
||||
&self,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
system: Option<String>,
|
||||
system: Option<Vec<SystemContent>>,
|
||||
) -> Result<String> {
|
||||
if self.thinking_enabled {
|
||||
self.streaming_messages_request(messages, system).await
|
||||
} else {
|
||||
self.non_streaming_messages_request(messages, system).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn non_streaming_messages_request(
|
||||
&self,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
system: Option<Vec<SystemContent>>,
|
||||
) -> Result<String> {
|
||||
let url = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
|
||||
let temperature = if self.temperature == 0.0 {
|
||||
None
|
||||
} else {
|
||||
Some(self.temperature)
|
||||
};
|
||||
|
||||
let request = MessagesRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: 500,
|
||||
temperature: Some(0.7),
|
||||
max_tokens: self.max_tokens,
|
||||
temperature,
|
||||
top_p: self.top_p,
|
||||
messages,
|
||||
system,
|
||||
thinking: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
@@ -179,35 +368,205 @@ impl AnthropicClient {
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Anthropic")?;
|
||||
|
||||
|
||||
let status = response.status();
|
||||
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type);
|
||||
bail!(
|
||||
"Anthropic API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
bail!("Anthropic API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: MessagesResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Anthropic response")?;
|
||||
|
||||
result.content
|
||||
|
||||
result
|
||||
.content
|
||||
.into_iter()
|
||||
.find(|c| c.content_type == "text")
|
||||
.map(|c| c.text.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
|
||||
}
|
||||
|
||||
/// Streaming request for thinking mode, filters thinking content blocks
|
||||
async fn streaming_messages_request(
|
||||
&self,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
system: Option<Vec<SystemContent>>,
|
||||
) -> Result<String> {
|
||||
let url = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
let thinking = ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
budget_tokens: self.thinking_budget_tokens,
|
||||
};
|
||||
|
||||
// max_tokens must exceed budget_tokens
|
||||
let max_tokens = (self.max_tokens).max(self.thinking_budget_tokens + 100);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens,
|
||||
temperature: None, // must be omitted for thinking mode
|
||||
top_p: None,
|
||||
messages,
|
||||
system,
|
||||
thinking: Some(thinking),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send streaming request to Anthropic")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"Anthropic API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("Anthropic API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let mut content_buffer = String::new();
|
||||
let mut in_thinking = false;
|
||||
let mut has_reasoning = false;
|
||||
let mut has_content = false;
|
||||
|
||||
let thinking_state = self.thinking_state.as_ref();
|
||||
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
use futures_util::StreamExt;
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||
let chunk_str =
|
||||
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||
|
||||
line_buffer.push_str(&chunk_str);
|
||||
|
||||
while let Some(line_end) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..line_end].trim().to_string();
|
||||
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse SSE event line
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if let Ok(event) = serde_json::from_str::<SseEvent>(data) {
|
||||
match event.event_type.as_str() {
|
||||
"content_block_start" => {
|
||||
if let Some(ref block) = event.content_block {
|
||||
if block.content_type == "thinking" {
|
||||
in_thinking = true;
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(ref delta) = event.delta {
|
||||
// Thinking delta - ignore content but track state
|
||||
if delta.thinking.is_some() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Text delta - collect
|
||||
if in_thinking && delta.text.is_some() {
|
||||
// Transition from thinking to text
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
in_thinking = false;
|
||||
}
|
||||
if let Some(ref text) = delta.text
|
||||
&& !text.is_empty()
|
||||
{
|
||||
has_content = true;
|
||||
content_buffer.push_str(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_stop" => {
|
||||
if in_thinking {
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
in_thinking = false;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure thinking state is ended
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
|
||||
let result = content_buffer.trim().to_string();
|
||||
|
||||
if result.is_empty() {
|
||||
if has_reasoning && !has_content {
|
||||
bail!(
|
||||
"Anthropic returned thinking content but no final answer. \
|
||||
The model may have entered an incomplete thinking state. \
|
||||
Please try again or disable thinking mode."
|
||||
);
|
||||
}
|
||||
bail!(
|
||||
"No response from Anthropic. \
|
||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Available Anthropic models
|
||||
/// Available Anthropic models (Claude 4 series with extended thinking)
|
||||
pub const ANTHROPIC_MODELS: &[&str] = &[
|
||||
"claude-opus-4-7",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-haiku-4-5",
|
||||
// Legacy models
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
@@ -216,7 +575,6 @@ pub const ANTHROPIC_MODELS: &[&str] = &[
|
||||
"claude-instant-1.2",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
ANTHROPIC_MODELS.contains(&model)
|
||||
}
|
||||
@@ -226,8 +584,61 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
fn test_model_validation_claude4() {
|
||||
assert!(is_valid_model("claude-opus-4-7"));
|
||||
assert!(is_valid_model("claude-sonnet-4-6"));
|
||||
assert!(is_valid_model("claude-haiku-4-5"));
|
||||
assert!(is_valid_model("claude-3-sonnet-20240229"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_config_serialization() {
|
||||
let config = ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
budget_tokens: 2048,
|
||||
};
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
assert!(json.contains(r#""type":"enabled""#));
|
||||
assert!(json.contains(r#""budget_tokens":2048"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_content_serialization() {
|
||||
let content = SystemContent {
|
||||
content_type: "text".to_string(),
|
||||
text: "You are helpful.".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&content).unwrap();
|
||||
assert!(json.contains(r#""type":"text""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_parsing_content_block_start() {
|
||||
let json = r#"{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}"#;
|
||||
let event: SseEvent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(event.event_type, "content_block_start");
|
||||
assert_eq!(
|
||||
event.content_block.unwrap().content_type,
|
||||
"thinking"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_parsing_text_delta() {
|
||||
let json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
|
||||
let event: SseEvent = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(event.event_type, "content_block_delta");
|
||||
assert_eq!(event.delta.unwrap().text, Some("Hello".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_content_text() {
|
||||
let msg = AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: AnthropicContent::Text("Hello".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(json.contains(r#""content":"Hello""#));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// DeepSeek API client
|
||||
@@ -11,6 +13,10 @@ pub struct DeepSeekClient {
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
reasoning_effort: Option<String>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -21,9 +27,17 @@ struct ChatCompletionRequest {
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
frequency_penalty: Option<f32>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thinking: Option<ThinkingConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reasoning_effort: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -32,10 +46,12 @@ struct ThinkingConfig {
|
||||
thinking_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -50,6 +66,29 @@ struct Choice {
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
// --- Streaming response structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChunk {
|
||||
choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChoice {
|
||||
delta: StreamDelta,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
index: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct StreamDelta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
@@ -63,22 +102,24 @@ struct ApiError {
|
||||
}
|
||||
|
||||
impl DeepSeekClient {
|
||||
/// Create new DeepSeek client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
let client = create_http_client(Duration::from_secs(300))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://api.deepseek.com/".to_string(),
|
||||
base_url: "https://api.deepseek.com".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
reasoning_effort: None,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
let client = create_http_client(Duration::from_secs(300))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
@@ -86,26 +127,48 @@ impl DeepSeekClient {
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
reasoning_effort: None,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Enable or disable thinking mode
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
|
||||
self.reasoning_effort = effort;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||
self.thinking_state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.send()
|
||||
@@ -120,11 +183,11 @@ impl DeepSeekClient {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
data: Vec<ModelId>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
struct ModelId {
|
||||
id: String,
|
||||
}
|
||||
|
||||
@@ -136,7 +199,6 @@ impl DeepSeekClient {
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
@@ -155,14 +217,13 @@ impl DeepSeekClient {
|
||||
#[async_trait]
|
||||
impl LlmProvider for DeepSeekClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
reasoning_content: None,
|
||||
}];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
@@ -172,15 +233,17 @@ impl LlmProvider for DeepSeekClient {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
@@ -193,6 +256,38 @@ impl LlmProvider for DeepSeekClient {
|
||||
}
|
||||
|
||||
impl DeepSeekClient {
|
||||
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=3 {
|
||||
match self.chat_completion(messages.clone()).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
// 网络临时错误才重试
|
||||
let is_retryable = err_msg.contains("timeout")
|
||||
|| err_msg.contains("connection")
|
||||
|| err_msg.contains("temporary")
|
||||
|| err_msg.contains("5")
|
||||
&& (err_msg.contains("500")
|
||||
|| err_msg.contains("502")
|
||||
|| err_msg.contains("503")
|
||||
|| err_msg.contains("504"));
|
||||
|
||||
if !is_retryable || attempt == 3 {
|
||||
last_error = Some(e);
|
||||
break;
|
||||
}
|
||||
|
||||
// 指数退避
|
||||
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
@@ -204,20 +299,59 @@ impl DeepSeekClient {
|
||||
None
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
stream: false,
|
||||
thinking,
|
||||
// 思考模式下,temperature/top_p 等参数不应传递
|
||||
// 非思考模式下可以正常传递
|
||||
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
|
||||
if self.thinking_enabled {
|
||||
(None, Some(self.max_tokens), None, None, None)
|
||||
} else {
|
||||
(
|
||||
Some(self.temperature),
|
||||
Some(self.max_tokens),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let reasoning_effort = if self.thinking_enabled {
|
||||
self.reasoning_effort.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: messages.clone(),
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
stream: self.thinking_enabled,
|
||||
thinking,
|
||||
reasoning_effort,
|
||||
};
|
||||
|
||||
if self.thinking_enabled {
|
||||
self.streaming_chat_completion(&url, &request).await
|
||||
} else {
|
||||
self.non_streaming_chat_completion(&url, &request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// 非流式请求(非思考模式)
|
||||
async fn non_streaming_chat_completion(
|
||||
&self,
|
||||
url: &str,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> Result<String> {
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.json(request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to DeepSeek")?;
|
||||
@@ -228,7 +362,11 @@ impl DeepSeekClient {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type);
|
||||
bail!(
|
||||
"DeepSeek API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("DeepSeek API error: {} - {}", status, text);
|
||||
@@ -239,34 +377,165 @@ impl DeepSeekClient {
|
||||
.await
|
||||
.context("Failed to parse DeepSeek response")?;
|
||||
|
||||
result.choices
|
||||
result
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| {
|
||||
let content = c.message.content.trim().to_string();
|
||||
if content.is_empty() {
|
||||
c.reasoning_content
|
||||
.map(|r| r.trim().to_string())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
content
|
||||
}
|
||||
})
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!(
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
|
||||
}
|
||||
|
||||
/// 流式请求(思考模式),处理 reasoning_content 和 content
|
||||
async fn streaming_chat_completion(
|
||||
&self,
|
||||
url: &str,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> Result<String> {
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send streaming request to DeepSeek")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"DeepSeek API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("DeepSeek API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let mut content_buffer = String::new();
|
||||
let mut has_reasoning = false;
|
||||
let mut has_content = false;
|
||||
let mut stream_ended = false;
|
||||
|
||||
let thinking_state = self.thinking_state.as_ref();
|
||||
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
use futures_util::StreamExt;
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||
let chunk_str =
|
||||
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||
|
||||
line_buffer.push_str(&chunk_str);
|
||||
|
||||
// 处理完整行
|
||||
while let Some(line_end) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..line_end].trim().to_string();
|
||||
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// SSE 格式:data: {...} 或 data: [DONE]
|
||||
if line == "data: [DONE]" {
|
||||
stream_ended = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
match serde_json::from_str::<StreamChunk>(json_str) {
|
||||
Ok(chunk) => {
|
||||
for choice in &chunk.choices {
|
||||
// 处理 reasoning_content
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||
&& !reasoning.is_empty() {
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
}
|
||||
// reasoning_content 不对外输出,仅用于内部状态判断
|
||||
continue;
|
||||
}
|
||||
|
||||
// 处理 content
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty() {
|
||||
// reasoning 结束,content 开始出现时移除 thinking 标识
|
||||
if has_reasoning && !has_content
|
||||
&& let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
}
|
||||
|
||||
// 检查 finish_reason
|
||||
if let Some(ref reason) = choice.finish_reason
|
||||
&& reason == "stop" {
|
||||
stream_ended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// 忽略无法解析的行(可能是心跳或注释)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stream_ended {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 确保思考状态已结束
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
|
||||
let result = content_buffer.trim().to_string();
|
||||
|
||||
if result.is_empty() {
|
||||
if has_reasoning && !has_content {
|
||||
bail!(
|
||||
"DeepSeek returned reasoning content but no final answer. \
|
||||
The model may have entered an incomplete thinking state. \
|
||||
Please try again or disable thinking mode."
|
||||
);
|
||||
}
|
||||
bail!(
|
||||
"No response from DeepSeek. \
|
||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Available DeepSeek models
|
||||
/// 可用 DeepSeek 模型列表
|
||||
/// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列
|
||||
pub const DEEPSEEK_MODELS: &[&str] = &[
|
||||
"deepseek-v4-flash",
|
||||
"deepseek-v4-pro",
|
||||
// 兼容旧版模型 ID(将于 2026-07-24 停用)
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
DEEPSEEK_MODELS.contains(&model)
|
||||
}
|
||||
@@ -276,9 +545,76 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
fn test_model_validation_v4() {
|
||||
assert!(is_valid_model("deepseek-v4-flash"));
|
||||
assert!(is_valid_model("deepseek-v4-pro"));
|
||||
assert!(is_valid_model("deepseek-chat"));
|
||||
assert!(is_valid_model("deepseek-reasoner"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
assert!(!is_valid_model("deepseek-v3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_builder_defaults() {
|
||||
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();
|
||||
assert!(!client.thinking_enabled);
|
||||
assert_eq!(client.max_tokens, 500);
|
||||
assert_eq!(client.temperature, 0.7);
|
||||
assert!(client.reasoning_effort.is_none());
|
||||
assert!(client.thinking_state.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_builder_with_thinking() {
|
||||
let client = DeepSeekClient::new("test-key", "deepseek-v4-flash")
|
||||
.unwrap()
|
||||
.with_thinking(true)
|
||||
.with_reasoning_effort(Some("high".to_string()))
|
||||
.with_max_tokens(1000)
|
||||
.with_temperature(0.5);
|
||||
|
||||
assert!(client.thinking_enabled);
|
||||
assert_eq!(client.reasoning_effort, Some("high".to_string()));
|
||||
assert_eq!(client.max_tokens, 1000);
|
||||
assert_eq!(client.temperature, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_config_serialization() {
|
||||
let config = ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
assert_eq!(json, r#"{"type":"enabled"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization_without_reasoning() {
|
||||
let msg = Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hello".to_string(),
|
||||
reasoning_content: None,
|
||||
};
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(!json.contains("reasoning_content"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_delta_parsing() {
|
||||
let json = r#"{"content":"Hello","reasoning_content":null}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(delta.content, Some("Hello".to_string()));
|
||||
assert!(delta.reasoning_content.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_delta_reasoning_only() {
|
||||
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert!(delta.content.is_none());
|
||||
assert_eq!(
|
||||
delta.reasoning_content,
|
||||
Some("Let me think...".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
855
src/llm/kimi.rs
855
src/llm/kimi.rs
@@ -1,284 +1,571 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Kimi API client (Moonshot AI)
|
||||
pub struct KimiClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thinking: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ThinkingConfig {
|
||||
#[serde(rename = "type")]
|
||||
thinking_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiError {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
}
|
||||
|
||||
impl KimiClient {
|
||||
/// Create new Kimi client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://api.moonshot.cn/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Enable or disable thinking mode
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list Kimi models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("Kimi API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Kimi response")?;
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for KimiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"kimi"
|
||||
}
|
||||
}
|
||||
|
||||
impl KimiClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let thinking = if self.thinking_enabled {
|
||||
Some(ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(1.0),
|
||||
stream: false,
|
||||
thinking,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Kimi")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("Kimi API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
bail!("Kimi API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Kimi response")?;
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| {
|
||||
let content = c.message.content.trim().to_string();
|
||||
if content.is_empty() {
|
||||
c.reasoning_content
|
||||
.map(|r| r.trim().to_string())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
content
|
||||
}
|
||||
})
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!(
|
||||
"No response from Kimi. \
|
||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Available Kimi models
|
||||
pub const KIMI_MODELS: &[&str] = &[
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
KIMI_MODELS.contains(&model)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
assert!(is_valid_model("moonshot-v1-8k"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
}
|
||||
}
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Kimi API client (Moonshot AI)
|
||||
pub struct KimiClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thinking: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ThinkingConfig {
|
||||
#[serde(rename = "type")]
|
||||
thinking_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
// --- Streaming response structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChunk {
|
||||
choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChoice {
|
||||
delta: StreamDelta,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
index: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct StreamDelta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiError {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
}
|
||||
|
||||
impl KimiClient {
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(300))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://api.moonshot.cn/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
max_tokens: 500,
|
||||
temperature: 1.0,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(300))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
max_tokens: 500,
|
||||
temperature: 1.0,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||
self.thinking_state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list Kimi models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("Kimi API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<ModelId>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelId {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Kimi response")?;
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for KimiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
reasoning_content: None,
|
||||
}];
|
||||
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
reasoning_content: None,
|
||||
});
|
||||
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"kimi"
|
||||
}
|
||||
}
|
||||
|
||||
impl KimiClient {
|
||||
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=3 {
|
||||
match self.chat_completion(messages.clone()).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
let is_retryable = err_msg.contains("timeout")
|
||||
|| err_msg.contains("connection")
|
||||
|| err_msg.contains("temporary")
|
||||
|| err_msg.contains("5")
|
||||
&& (err_msg.contains("500")
|
||||
|| err_msg.contains("502")
|
||||
|| err_msg.contains("503")
|
||||
|| err_msg.contains("504"));
|
||||
|
||||
if !is_retryable || attempt == 3 {
|
||||
last_error = Some(e);
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let thinking = if self.thinking_enabled {
|
||||
Some(ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// 对于 kimi-k2.6 等支持思考模式的模型,使用默认 temperature 即可
|
||||
// 思考模式下不显式指定 temperature
|
||||
let temperature = if self.thinking_enabled {
|
||||
None
|
||||
} else {
|
||||
Some(self.temperature)
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: messages.clone(),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature,
|
||||
stream: self.thinking_enabled,
|
||||
thinking,
|
||||
};
|
||||
|
||||
if self.thinking_enabled {
|
||||
self.streaming_chat_completion(&url, &request).await
|
||||
} else {
|
||||
self.non_streaming_chat_completion(&url, &request).await
|
||||
}
|
||||
}
|
||||
|
||||
/// 非流式请求(非思考模式)
|
||||
async fn non_streaming_chat_completion(
|
||||
&self,
|
||||
url: &str,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> Result<String> {
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Kimi")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"Kimi API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("Kimi API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Kimi response")?;
|
||||
|
||||
result
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
|
||||
}
|
||||
|
||||
/// 流式请求(思考模式),处理 reasoning_content 和 content
|
||||
async fn streaming_chat_completion(
|
||||
&self,
|
||||
url: &str,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> Result<String> {
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send streaming request to Kimi")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"Kimi API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("Kimi API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let mut content_buffer = String::new();
|
||||
let mut has_reasoning = false;
|
||||
let mut has_content = false;
|
||||
let mut stream_ended = false;
|
||||
|
||||
let thinking_state = self.thinking_state.as_ref();
|
||||
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
use futures_util::StreamExt;
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||
let chunk_str =
|
||||
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||
|
||||
line_buffer.push_str(&chunk_str);
|
||||
|
||||
while let Some(line_end) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..line_end].trim().to_string();
|
||||
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line == "data: [DONE]" {
|
||||
stream_ended = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
match serde_json::from_str::<StreamChunk>(json_str) {
|
||||
Ok(chunk) => {
|
||||
for choice in &chunk.choices {
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||
&& !reasoning.is_empty() {
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty() {
|
||||
if has_reasoning && !has_content
|
||||
&& let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
}
|
||||
|
||||
if let Some(ref reason) = choice.finish_reason
|
||||
&& reason == "stop" {
|
||||
stream_ended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// 忽略无法解析的行
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stream_ended {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 确保思考状态已结束
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
|
||||
let result = content_buffer.trim().to_string();
|
||||
|
||||
if result.is_empty() {
|
||||
if has_reasoning && !has_content {
|
||||
bail!(
|
||||
"Kimi returned reasoning content but no final answer. \
|
||||
The model may have entered an incomplete thinking state. \
|
||||
Please try again or disable thinking mode."
|
||||
);
|
||||
}
|
||||
bail!(
|
||||
"No response from Kimi. \
|
||||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// 可用 Kimi 模型列表
|
||||
pub const KIMI_MODELS: &[&str] = &[
|
||||
// K2 系列(推荐)
|
||||
"kimi-k2.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-instruct",
|
||||
"kimi-k2-instruct-0905",
|
||||
// 兼容旧版模型 ID
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k",
|
||||
];
|
||||
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
KIMI_MODELS.contains(&model)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation_k2() {
|
||||
assert!(is_valid_model("kimi-k2.6"));
|
||||
assert!(is_valid_model("kimi-k2.5"));
|
||||
assert!(is_valid_model("kimi-k2-thinking"));
|
||||
assert!(is_valid_model("kimi-k2-thinking-turbo"));
|
||||
assert!(is_valid_model("moonshot-v1-8k"));
|
||||
assert!(is_valid_model("moonshot-v1-32k"));
|
||||
assert!(is_valid_model("moonshot-v1-128k"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
assert!(!is_valid_model("kimi-k1.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_builder_defaults() {
|
||||
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
|
||||
assert!(!client.thinking_enabled);
|
||||
assert_eq!(client.max_tokens, 500);
|
||||
assert_eq!(client.temperature, 1.0);
|
||||
assert!(client.thinking_state.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_builder_with_thinking() {
|
||||
let client = KimiClient::new("test-key", "kimi-k2.6")
|
||||
.unwrap()
|
||||
.with_thinking(true)
|
||||
.with_max_tokens(1000)
|
||||
.with_temperature(0.5);
|
||||
|
||||
assert!(client.thinking_enabled);
|
||||
assert_eq!(client.max_tokens, 1000);
|
||||
assert_eq!(client.temperature, 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_config_serialization() {
|
||||
let config = ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
assert_eq!(json, r#"{"type":"enabled"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_new_defaults() {
|
||||
let client = KimiClient::new("test-key", "kimi-k2.6").unwrap();
|
||||
assert_eq!(client.name(), "kimi");
|
||||
assert!(!client.thinking_enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization() {
|
||||
let msg = Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hello".to_string(),
|
||||
reasoning_content: None,
|
||||
};
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
assert!(!json.contains("reasoning_content"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ pub mod anthropic;
|
||||
pub mod kimi;
|
||||
pub mod deepseek;
|
||||
pub mod openrouter;
|
||||
pub mod thinking;
|
||||
|
||||
pub use ollama::OllamaClient;
|
||||
pub use openai::OpenAiClient;
|
||||
@@ -61,48 +62,114 @@ impl Default for LlmClientConfig {
|
||||
impl LlmClient {
|
||||
/// Create LLM client from configuration manager
|
||||
pub async fn from_config(manager: &crate::config::manager::ConfigManager) -> Result<Self> {
|
||||
Self::from_config_with_think(manager, manager.config().llm.thinking_enabled).await
|
||||
}
|
||||
|
||||
/// Create LLM client from configuration with explicit thinking override
|
||||
pub async fn from_config_with_think(
|
||||
manager: &crate::config::manager::ConfigManager,
|
||||
thinking_enabled: bool,
|
||||
) -> Result<Self> {
|
||||
let config = manager.config();
|
||||
let client_config = LlmClientConfig {
|
||||
max_tokens: config.llm.max_tokens,
|
||||
temperature: config.llm.temperature,
|
||||
timeout: Duration::from_secs(config.llm.timeout),
|
||||
thinking_enabled: config.llm.thinking_enabled,
|
||||
thinking_enabled,
|
||||
};
|
||||
|
||||
let provider = config.llm.provider.as_str();
|
||||
let model = config.llm.model.as_str();
|
||||
let base_url = manager.llm_base_url();
|
||||
let api_key = manager.get_api_key();
|
||||
let thinking_enabled = config.llm.thinking_enabled;
|
||||
|
||||
let provider: Box<dyn LlmProvider> = match provider {
|
||||
"ollama" => {
|
||||
Box::new(OllamaClient::new(&base_url, model))
|
||||
Box::new(OllamaClient::new(&base_url, model)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature))
|
||||
}
|
||||
"openai" => {
|
||||
let key = api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
|
||||
Box::new(OpenAiClient::new(&base_url, key, model)?)
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut client = OpenAiClient::new(&base_url, key, model)?
|
||||
.with_thinking(thinking_enabled)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?;
|
||||
if let Some(state) = thinking_state {
|
||||
client = client.with_thinking_state(state);
|
||||
}
|
||||
Box::new(client)
|
||||
}
|
||||
"anthropic" => {
|
||||
let key = api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
|
||||
Box::new(AnthropicClient::new(key, model)?)
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let budget = config.llm.thinking_budget_tokens.unwrap_or(1024);
|
||||
let mut client = AnthropicClient::new(key, model)?
|
||||
.with_thinking(thinking_enabled)
|
||||
.with_thinking_budget_tokens(budget)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?;
|
||||
if let Some(state) = thinking_state {
|
||||
client = client.with_thinking_state(state);
|
||||
}
|
||||
Box::new(client)
|
||||
}
|
||||
"kimi" => {
|
||||
let key = api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Kimi API key not configured"))?;
|
||||
Box::new(KimiClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled))
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut client = KimiClient::with_base_url(key, model, &base_url)?
|
||||
.with_thinking(thinking_enabled)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?;
|
||||
if let Some(state) = thinking_state {
|
||||
client = client.with_thinking_state(state);
|
||||
}
|
||||
Box::new(client)
|
||||
}
|
||||
"deepseek" => {
|
||||
let key = api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("DeepSeek API key not configured"))?;
|
||||
Box::new(DeepSeekClient::with_base_url(key, model, &base_url)?.with_thinking(thinking_enabled))
|
||||
let thinking_state = if thinking_enabled {
|
||||
Some(thinking::create_console_thinking_state())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut client = DeepSeekClient::with_base_url(key, model, &base_url)?
|
||||
.with_thinking(thinking_enabled)
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?;
|
||||
if let Some(state) = thinking_state {
|
||||
client = client.with_thinking_state(state);
|
||||
}
|
||||
Box::new(client)
|
||||
}
|
||||
"openrouter" => {
|
||||
let key = api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not configured"))?;
|
||||
Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?)
|
||||
Box::new(OpenRouterClient::with_base_url(key, model, &base_url)?
|
||||
.with_max_tokens(client_config.max_tokens)
|
||||
.with_temperature(client_config.temperature)
|
||||
.with_timeout(client_config.timeout)?)
|
||||
}
|
||||
_ => bail!("Unknown LLM provider: {}", provider),
|
||||
};
|
||||
|
||||
@@ -9,6 +9,9 @@ pub struct OllamaClient {
|
||||
base_url: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -49,11 +52,14 @@ impl OllamaClient {
|
||||
pub fn new(base_url: &str, model: &str) -> Self {
|
||||
let client = create_http_client(Duration::from_secs(120))
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
|
||||
Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +70,21 @@ impl OllamaClient {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
@@ -143,8 +164,8 @@ impl LlmProvider for OllamaClient {
|
||||
system,
|
||||
stream: false,
|
||||
options: GenerationOptions {
|
||||
temperature: Some(0.7),
|
||||
num_predict: Some(500),
|
||||
temperature: Some(self.temperature),
|
||||
num_predict: Some(self.max_tokens),
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
use super::thinking::ThinkingStateManager;
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// OpenAI API client
|
||||
/// OpenAI API client with o-series reasoning support
|
||||
pub struct OpenAiClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
reasoning_effort: Option<String>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -20,10 +28,14 @@ struct ChatCompletionRequest {
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reasoning_effort: Option<String>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
@@ -39,6 +51,28 @@ struct Choice {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
// --- Streaming response structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChunk {
|
||||
choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct StreamChoice {
|
||||
delta: StreamDelta,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct StreamDelta {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
@@ -55,57 +89,91 @@ impl OpenAiClient {
|
||||
/// Create new OpenAI client
|
||||
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
reasoning_effort: None,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
|
||||
self.reasoning_effort = effort;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||||
self.thinking_state = Some(state);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list OpenAI models")?;
|
||||
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
id: String,
|
||||
}
|
||||
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenAI response")?;
|
||||
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
@@ -124,32 +192,30 @@ impl OpenAiClient {
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenAiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}];
|
||||
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
|
||||
self.chat_completion_with_retry(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
@@ -162,18 +228,59 @@ impl LlmProvider for OpenAiClient {
|
||||
}
|
||||
|
||||
impl OpenAiClient {
|
||||
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let mut last_error = None;
|
||||
|
||||
for attempt in 1..=3 {
|
||||
match self.chat_completion(messages.clone()).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
let is_retryable = err_msg.contains("timeout")
|
||||
|| err_msg.contains("connection")
|
||||
|| err_msg.contains("temporary")
|
||||
|| err_msg.contains("5")
|
||||
&& (err_msg.contains("500")
|
||||
|| err_msg.contains("502")
|
||||
|| err_msg.contains("503")
|
||||
|| err_msg.contains("504"));
|
||||
|
||||
if !is_retryable || attempt == 3 {
|
||||
last_error = Some(e);
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
if self.thinking_enabled {
|
||||
self.streaming_chat_completion(messages).await
|
||||
} else {
|
||||
self.non_streaming_chat_completion(messages).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn non_streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
top_p: self.top_p,
|
||||
reasoning_effort: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -181,31 +288,165 @@ impl OpenAiClient {
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to OpenAI")?;
|
||||
|
||||
|
||||
let status = response.status();
|
||||
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type);
|
||||
bail!(
|
||||
"OpenAI API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
bail!("OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenAI response")?;
|
||||
|
||||
result.choices
|
||||
|
||||
result
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||
}
|
||||
|
||||
/// Streaming request for reasoning mode, filters reasoning_content from output
|
||||
async fn streaming_chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
// For reasoning/thinking mode, omit temperature and top_p
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
reasoning_effort: self.reasoning_effort.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send streaming request to OpenAI")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!(
|
||||
"OpenAI API error: {} ({})",
|
||||
error.error.message,
|
||||
error.error.error_type
|
||||
);
|
||||
}
|
||||
|
||||
bail!("OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let mut content_buffer = String::new();
|
||||
let mut has_reasoning = false;
|
||||
let mut has_content = false;
|
||||
|
||||
let thinking_state = self.thinking_state.as_ref();
|
||||
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
use futures_util::StreamExt;
|
||||
|
||||
while let Some(chunk) = byte_stream.next().await {
|
||||
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||||
let chunk_str =
|
||||
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||||
|
||||
line_buffer.push_str(&chunk_str);
|
||||
|
||||
while let Some(line_end) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..line_end].trim().to_string();
|
||||
line_buffer = line_buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line == "data: [DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if let Ok(chunk) = serde_json::from_str::<StreamChunk>(json_str) {
|
||||
for choice in &chunk.choices {
|
||||
// Handle reasoning_content (o-series)
|
||||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||||
&& !reasoning.is_empty()
|
||||
{
|
||||
if !has_reasoning {
|
||||
has_reasoning = true;
|
||||
if let Some(state) = thinking_state {
|
||||
state.start_thinking();
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle content
|
||||
if let Some(ref content) = choice.delta.content
|
||||
&& !content.is_empty()
|
||||
{
|
||||
if has_reasoning && !has_content
|
||||
&& let Some(state) = thinking_state
|
||||
{
|
||||
state.end_thinking();
|
||||
}
|
||||
has_content = true;
|
||||
content_buffer.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(state) = thinking_state {
|
||||
state.end_thinking();
|
||||
}
|
||||
|
||||
let result = content_buffer.trim().to_string();
|
||||
|
||||
if result.is_empty() {
|
||||
if has_reasoning && !has_content {
|
||||
bail!(
|
||||
"OpenAI returned reasoning content but no final answer. \
|
||||
The model may have entered an incomplete reasoning state. \
|
||||
Please try again or disable thinking mode."
|
||||
);
|
||||
}
|
||||
bail!(
|
||||
"No response from OpenAI. \
|
||||
If thinking mode is enabled, try disabling it or ensure the model supports reasoning."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
|
||||
@@ -215,10 +456,15 @@ pub struct AzureOpenAiClient {
|
||||
deployment: String,
|
||||
api_version: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
reasoning_effort: Option<String>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||||
}
|
||||
|
||||
impl AzureOpenAiClient {
|
||||
/// Create new Azure OpenAI client
|
||||
pub fn new(
|
||||
endpoint: &str,
|
||||
api_key: &str,
|
||||
@@ -226,13 +472,19 @@ impl AzureOpenAiClient {
|
||||
api_version: &str,
|
||||
) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
endpoint: endpoint.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
deployment: deployment.to_string(),
|
||||
api_version: api_version.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
reasoning_effort: None,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
thinking_state: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -241,16 +493,19 @@ impl AzureOpenAiClient {
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint, self.deployment, self.api_version
|
||||
);
|
||||
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.deployment.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
top_p: self.top_p,
|
||||
reasoning_effort: self.reasoning_effort.clone(),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -258,22 +513,24 @@ impl AzureOpenAiClient {
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Azure OpenAI")?;
|
||||
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("Azure OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Azure OpenAI response")?;
|
||||
|
||||
result.choices
|
||||
|
||||
result
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
|
||||
}
|
||||
}
|
||||
@@ -281,41 +538,38 @@ impl AzureOpenAiClient {
|
||||
#[async_trait]
|
||||
impl LlmProvider for AzureOpenAiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
// Simple check - try to make a minimal request
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint, self.deployment, self.api_version
|
||||
);
|
||||
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.deployment.clone(),
|
||||
messages: vec![Message {
|
||||
@@ -324,10 +578,13 @@ impl LlmProvider for AzureOpenAiClient {
|
||||
}],
|
||||
max_tokens: Some(5),
|
||||
temperature: Some(0.0),
|
||||
top_p: None,
|
||||
reasoning_effort: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
match self.client
|
||||
|
||||
match self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.json(&request)
|
||||
@@ -343,3 +600,59 @@ impl LlmProvider for AzureOpenAiClient {
|
||||
"azure-openai"
|
||||
}
|
||||
}
|
||||
|
||||
/// Available OpenAI models (including o-series with reasoning)
|
||||
pub const OPENAI_MODELS: &[&str] = &[
|
||||
"o4-mini",
|
||||
"o3",
|
||||
"o3-mini",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o1-pro",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4",
|
||||
"gpt-3.5-turbo",
|
||||
];
|
||||
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
OPENAI_MODELS.contains(&model)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation_o_series() {
|
||||
assert!(is_valid_model("o4-mini"));
|
||||
assert!(is_valid_model("o3"));
|
||||
assert!(is_valid_model("o1"));
|
||||
assert!(is_valid_model("gpt-4o"));
|
||||
assert!(is_valid_model("gpt-3.5-turbo"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_delta_reasoning_parsing() {
|
||||
let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert!(delta.content.is_none());
|
||||
assert_eq!(
|
||||
delta.reasoning_content,
|
||||
Some("Let me think...".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stream_delta_content_parsing() {
|
||||
let json = r#"{"content":"Hello","reasoning_content":null}"#;
|
||||
let delta: StreamDelta = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(delta.content, Some("Hello".to_string()));
|
||||
assert!(delta.reasoning_content.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,9 @@ pub struct OpenRouterClient {
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -55,24 +58,30 @@ impl OpenRouterClient {
|
||||
/// Create new OpenRouter client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://openrouter.ai/api/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
top_p: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,6 +91,21 @@ impl OpenRouterClient {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = max_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
@@ -182,8 +206,8 @@ impl OpenRouterClient {
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(self.max_tokens),
|
||||
temperature: Some(self.temperature),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
|
||||
152
src/llm/thinking.rs
Normal file
152
src/llm/thinking.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// 统一的思考状态管理器,用于管理模型思考状态的显示与隐藏
|
||||
pub struct ThinkingStateManager {
|
||||
is_thinking: AtomicBool,
|
||||
on_start: Option<Box<dyn Fn() + Send + Sync>>,
|
||||
on_end: Option<Box<dyn Fn() + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl ThinkingStateManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
is_thinking: AtomicBool::new(false),
|
||||
on_start: None,
|
||||
on_end: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置思考开始回调
|
||||
pub fn on_thinking_start<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
|
||||
self.on_start = Some(Box::new(callback));
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置思考结束回调
|
||||
pub fn on_thinking_end<F: Fn() + Send + Sync + 'static>(mut self, callback: F) -> Self {
|
||||
self.on_end = Some(Box::new(callback));
|
||||
self
|
||||
}
|
||||
|
||||
/// 开始思考状态
|
||||
pub fn start_thinking(&self) {
|
||||
if !self.is_thinking.load(Ordering::SeqCst) {
|
||||
self.is_thinking.store(true, Ordering::SeqCst);
|
||||
if let Some(ref cb) = self.on_start {
|
||||
cb();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 结束思考状态
|
||||
pub fn end_thinking(&self) {
|
||||
if self.is_thinking.load(Ordering::SeqCst) {
|
||||
self.is_thinking.store(false, Ordering::SeqCst);
|
||||
if let Some(ref cb) = self.on_end {
|
||||
cb();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 当前是否处于思考状态
|
||||
pub fn is_thinking(&self) -> bool {
|
||||
self.is_thinking.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ThinkingStateManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// 线程安全的思考状态管理器引用
|
||||
pub type SharedThinkingState = Arc<ThinkingStateManager>;
|
||||
|
||||
/// 创建带有默认控制台输出的思考状态管理器
|
||||
/// 在思考开始时打印 "thinking...",在思考结束时清除该标识
|
||||
pub fn create_console_thinking_state() -> SharedThinkingState {
|
||||
Arc::new(
|
||||
ThinkingStateManager::new()
|
||||
.on_thinking_start(|| {
|
||||
eprint!("\rthinking...");
|
||||
})
|
||||
.on_thinking_end(|| {
|
||||
eprint!("\r \r");
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[test]
|
||||
fn test_thinking_state_transitions() {
|
||||
let manager = ThinkingStateManager::new();
|
||||
assert!(!manager.is_thinking());
|
||||
|
||||
manager.start_thinking();
|
||||
assert!(manager.is_thinking());
|
||||
|
||||
manager.end_thinking();
|
||||
assert!(!manager.is_thinking());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_idempotent_start() {
|
||||
let manager = ThinkingStateManager::new();
|
||||
manager.start_thinking();
|
||||
manager.start_thinking(); // 重复调用不应触发回调两次
|
||||
assert!(manager.is_thinking());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_idempotent_end() {
|
||||
let manager = ThinkingStateManager::new();
|
||||
manager.end_thinking(); // 未开始时结束不应触发问题
|
||||
assert!(!manager.is_thinking());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_callbacks() {
|
||||
let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_clone = events.clone();
|
||||
|
||||
let manager = ThinkingStateManager::new()
|
||||
.on_thinking_start(move || {
|
||||
events_clone.lock().unwrap().push("start".to_string());
|
||||
});
|
||||
|
||||
let events_clone2 = events.clone();
|
||||
let manager = manager.on_thinking_end(move || {
|
||||
events_clone2.lock().unwrap().push("end".to_string());
|
||||
});
|
||||
|
||||
manager.start_thinking();
|
||||
manager.end_thinking();
|
||||
|
||||
let recorded = events.lock().unwrap();
|
||||
assert_eq!(recorded.len(), 2);
|
||||
assert_eq!(recorded[0], "start");
|
||||
assert_eq!(recorded[1], "end");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_console_thinking_state() {
|
||||
let state = create_console_thinking_state();
|
||||
assert!(!state.is_thinking());
|
||||
state.start_thinking();
|
||||
assert!(state.is_thinking());
|
||||
state.end_thinking();
|
||||
assert!(!state.is_thinking());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default() {
|
||||
let manager = ThinkingStateManager::default();
|
||||
assert!(!manager.is_thinking());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user