use super::thinking::ThinkingStateManager; use super::{LlmProvider, create_http_client}; use anyhow::{Context, Result, bail}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::Arc; use std::time::Duration; /// DeepSeek API client pub struct DeepSeekClient { base_url: String, api_key: String, model: String, client: reqwest::Client, thinking_enabled: bool, reasoning_effort: Option, max_tokens: u32, temperature: f32, thinking_state: Option>, } #[derive(Debug, Serialize)] struct ChatCompletionRequest { model: String, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] presence_penalty: Option, #[serde(skip_serializing_if = "Option::is_none")] frequency_penalty: Option, stream: bool, #[serde(skip_serializing_if = "Option::is_none")] thinking: Option, #[serde(skip_serializing_if = "Option::is_none")] reasoning_effort: Option, } #[derive(Debug, Serialize)] struct ThinkingConfig { #[serde(rename = "type")] thinking_type: String, } #[derive(Debug, Clone, Serialize, Deserialize)] struct Message { role: String, content: String, #[serde(skip_serializing_if = "Option::is_none")] reasoning_content: Option, } #[derive(Debug, Deserialize)] struct ChatCompletionResponse { choices: Vec, } #[derive(Debug, Deserialize)] struct Choice { message: Message, #[serde(default)] reasoning_content: Option, } // --- Streaming response structures --- #[derive(Debug, Deserialize)] struct StreamChunk { choices: Vec, } #[derive(Debug, Deserialize)] struct StreamChoice { delta: StreamDelta, #[serde(default)] finish_reason: Option, index: Option, } #[derive(Debug, Deserialize, Default)] struct StreamDelta { #[serde(default)] content: Option, #[serde(default)] reasoning_content: Option, } #[derive(Debug, Deserialize)] struct ErrorResponse { error: ApiError, } #[derive(Debug, Deserialize)] struct ApiError { message: String, #[serde(rename = "type")] error_type: String, } impl DeepSeekClient { pub fn new(api_key: &str, model: &str) -> Result { let client = create_http_client(Duration::from_secs(300))?; Ok(Self { base_url: "https://api.deepseek.com".to_string(), api_key: api_key.to_string(), model: model.to_string(), client, thinking_enabled: false, reasoning_effort: None, max_tokens: 500, temperature: 0.7, thinking_state: None, }) } pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { let client = create_http_client(Duration::from_secs(300))?; Ok(Self { base_url: base_url.trim_end_matches('/').to_string(), api_key: api_key.to_string(), model: model.to_string(), client, thinking_enabled: false, reasoning_effort: None, max_tokens: 500, temperature: 0.7, thinking_state: None, }) } pub fn with_timeout(mut self, timeout: Duration) -> Result { self.client = create_http_client(timeout)?; Ok(self) } pub fn with_thinking(mut self, enabled: bool) -> Self { self.thinking_enabled = enabled; self } pub fn with_reasoning_effort(mut self, effort: Option) -> Self { self.reasoning_effort = effort; self } pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { self.max_tokens = max_tokens; self } pub fn with_temperature(mut self, temperature: f32) -> Self { self.temperature = temperature; self } pub fn with_thinking_state(mut self, state: Arc) -> Self { self.thinking_state = Some(state); self } pub async fn list_models(&self) -> Result> { let url = format!("{}/models", self.base_url); let response = self .client .get(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .send() .await .context("Failed to list DeepSeek models")?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); bail!("DeepSeek API error: {} - {}", status, text); } #[derive(Deserialize)] struct ModelsResponse { data: Vec, } #[derive(Deserialize)] struct ModelId { id: String, } let result: ModelsResponse = response .json() .await .context("Failed to parse DeepSeek response")?; Ok(result.data.into_iter().map(|m| m.id).collect()) } pub async fn validate_key(&self) -> Result { match self.list_models().await { Ok(_) => Ok(true), Err(e) => { let err_str = e.to_string(); if err_str.contains("401") || err_str.contains("Unauthorized") { Ok(false) } else { Err(e) } } } } } #[async_trait] impl LlmProvider for DeepSeekClient { async fn generate(&self, prompt: &str) -> Result { let messages = vec![Message { role: "user".to_string(), content: prompt.to_string(), reasoning_content: None, }]; self.chat_completion_with_retry(messages).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let mut messages = vec![]; if !system.is_empty() { messages.push(Message { role: "system".to_string(), content: system.to_string(), reasoning_content: None, }); } messages.push(Message { role: "user".to_string(), content: user.to_string(), reasoning_content: None, }); self.chat_completion_with_retry(messages).await } async fn is_available(&self) -> bool { self.validate_key().await.unwrap_or(false) } fn name(&self) -> &str { "deepseek" } } impl DeepSeekClient { async fn chat_completion_with_retry(&self, messages: Vec) -> Result { let mut last_error = None; for attempt in 1..=3 { match self.chat_completion(messages.clone()).await { Ok(result) => return Ok(result), Err(e) => { let err_msg = e.to_string(); // 网络临时错误才重试 let is_retryable = err_msg.contains("timeout") || err_msg.contains("connection") || err_msg.contains("temporary") || err_msg.contains("5") && (err_msg.contains("500") || err_msg.contains("502") || err_msg.contains("503") || err_msg.contains("504")); if !is_retryable || attempt == 3 { last_error = Some(e); break; } // 指数退避 tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await; } } } Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries"))) } async fn chat_completion(&self, messages: Vec) -> Result { let url = format!("{}/chat/completions", self.base_url); let thinking = if self.thinking_enabled { Some(ThinkingConfig { thinking_type: "enabled".to_string(), }) } else { None }; // 思考模式下,temperature/top_p 等参数不应传递 // 非思考模式下可以正常传递 let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) = if self.thinking_enabled { (None, Some(self.max_tokens), None, None, None) } else { ( Some(self.temperature), Some(self.max_tokens), None, None, None, ) }; let reasoning_effort = if self.thinking_enabled { self.reasoning_effort.clone() } else { None }; let request = ChatCompletionRequest { model: self.model.clone(), messages: messages.clone(), max_tokens, temperature, top_p, presence_penalty, frequency_penalty, stream: self.thinking_enabled, thinking, reasoning_effort, }; if self.thinking_enabled { self.streaming_chat_completion(&url, &request).await } else { self.non_streaming_chat_completion(&url, &request).await } } /// 非流式请求(非思考模式) async fn non_streaming_chat_completion( &self, url: &str, request: &ChatCompletionRequest, ) -> Result { let response = self .client .post(url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .json(request) .send() .await .context("Failed to send request to DeepSeek")?; let status = response.status(); if !status.is_success() { let text = response.text().await.unwrap_or_default(); if let Ok(error) = serde_json::from_str::(&text) { bail!( "DeepSeek API error: {} ({})", error.error.message, error.error.error_type ); } bail!("DeepSeek API error: {} - {}", status, text); } let result: ChatCompletionResponse = response .json() .await .context("Failed to parse DeepSeek response")?; result .choices .into_iter() .next() .map(|c| c.message.content.trim().to_string()) .filter(|s| !s.is_empty()) .ok_or_else(|| anyhow::anyhow!("No response from DeepSeek")) } /// 流式请求(思考模式),处理 reasoning_content 和 content async fn streaming_chat_completion( &self, url: &str, request: &ChatCompletionRequest, ) -> Result { let response = self .client .post(url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .header("Accept", "text/event-stream") .json(request) .send() .await .context("Failed to send streaming request to DeepSeek")?; let status = response.status(); if !status.is_success() { let text = response.text().await.unwrap_or_default(); if let Ok(error) = serde_json::from_str::(&text) { bail!( "DeepSeek API error: {} ({})", error.error.message, error.error.error_type ); } bail!("DeepSeek API error: {} - {}", status, text); } let mut content_buffer = String::new(); let mut has_reasoning = false; let mut has_content = false; let mut stream_ended = false; let thinking_state = self.thinking_state.as_ref(); let mut byte_stream = response.bytes_stream(); let mut line_buffer = String::new(); use futures_util::StreamExt; while let Some(chunk) = byte_stream.next().await { let chunk = chunk.context("Failed to read streaming response chunk")?; let chunk_str = String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?; line_buffer.push_str(&chunk_str); // 处理完整行 while let Some(line_end) = line_buffer.find('\n') { let line = line_buffer[..line_end].trim().to_string(); line_buffer = line_buffer[line_end + 1..].to_string(); if line.is_empty() { continue; } // SSE 格式:data: {...} 或 data: [DONE] if line == "data: [DONE]" { stream_ended = true; break; } if let Some(json_str) = line.strip_prefix("data: ") { match serde_json::from_str::(json_str) { Ok(chunk) => { for choice in &chunk.choices { // 处理 reasoning_content if let Some(ref reasoning) = choice.delta.reasoning_content && !reasoning.is_empty() { if !has_reasoning { has_reasoning = true; if let Some(state) = thinking_state { state.start_thinking(); } } // reasoning_content 不对外输出,仅用于内部状态判断 continue; } // 处理 content if let Some(ref content) = choice.delta.content && !content.is_empty() { // reasoning 结束,content 开始出现时移除 thinking 标识 if has_reasoning && !has_content && let Some(state) = thinking_state { state.end_thinking(); } has_content = true; content_buffer.push_str(content); } // 检查 finish_reason if let Some(ref reason) = choice.finish_reason && reason == "stop" { stream_ended = true; } } } Err(_) => { // 忽略无法解析的行(可能是心跳或注释) } } } } if stream_ended { break; } } // 确保思考状态已结束 if let Some(state) = thinking_state { state.end_thinking(); } let result = content_buffer.trim().to_string(); if result.is_empty() { if has_reasoning && !has_content { bail!( "DeepSeek returned reasoning content but no final answer. \ The model may have entered an incomplete thinking state. \ Please try again or disable thinking mode." ); } bail!( "No response from DeepSeek. \ If thinking mode is enabled, try disabling it or ensure the model supports it." ); } Ok(result) } } /// 可用 DeepSeek 模型列表 /// deepseek-chat / deepseek-reasoner 将于 2026-07-24 停用,推荐使用 V4 系列 pub const DEEPSEEK_MODELS: &[&str] = &[ "deepseek-v4-flash", "deepseek-v4-pro", // 兼容旧版模型 ID(将于 2026-07-24 停用) "deepseek-chat", "deepseek-reasoner", ]; pub fn is_valid_model(model: &str) -> bool { DEEPSEEK_MODELS.contains(&model) } #[cfg(test)] mod tests { use super::*; #[test] fn test_model_validation_v4() { assert!(is_valid_model("deepseek-v4-flash")); assert!(is_valid_model("deepseek-v4-pro")); assert!(is_valid_model("deepseek-chat")); assert!(is_valid_model("deepseek-reasoner")); assert!(!is_valid_model("invalid-model")); assert!(!is_valid_model("deepseek-v3")); } #[test] fn test_client_builder_defaults() { let client = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap(); assert!(!client.thinking_enabled); assert_eq!(client.max_tokens, 500); assert_eq!(client.temperature, 0.7); assert!(client.reasoning_effort.is_none()); assert!(client.thinking_state.is_none()); } #[test] fn test_client_builder_with_thinking() { let client = DeepSeekClient::new("test-key", "deepseek-v4-flash") .unwrap() .with_thinking(true) .with_reasoning_effort(Some("high".to_string())) .with_max_tokens(1000) .with_temperature(0.5); assert!(client.thinking_enabled); assert_eq!(client.reasoning_effort, Some("high".to_string())); assert_eq!(client.max_tokens, 1000); assert_eq!(client.temperature, 0.5); } #[test] fn test_thinking_config_serialization() { let config = ThinkingConfig { thinking_type: "enabled".to_string(), }; let json = serde_json::to_string(&config).unwrap(); assert_eq!(json, r#"{"type":"enabled"}"#); } #[test] fn test_message_serialization_without_reasoning() { let msg = Message { role: "user".to_string(), content: "Hello".to_string(), reasoning_content: None, }; let json = serde_json::to_string(&msg).unwrap(); assert!(!json.contains("reasoning_content")); } #[test] fn test_stream_delta_parsing() { let json = r#"{"content":"Hello","reasoning_content":null}"#; let delta: StreamDelta = serde_json::from_str(json).unwrap(); assert_eq!(delta.content, Some("Hello".to_string())); assert!(delta.reasoning_content.is_none()); } #[test] fn test_stream_delta_reasoning_only() { let json = r#"{"content":null,"reasoning_content":"Let me think..."}"#; let delta: StreamDelta = serde_json::from_str(json).unwrap(); assert!(delta.content.is_none()); assert_eq!(delta.reasoning_content, Some("Let me think...".to_string())); } }