623 lines
19 KiB
Rust
623 lines
19 KiB
Rust
use super::thinking::ThinkingStateManager;
|
||
use super::{LlmProvider, create_http_client};
|
||
use anyhow::{Context, Result, bail};
|
||
use async_trait::async_trait;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
/// DeepSeek API client
|
||
pub struct DeepSeekClient {
|
||
base_url: String,
|
||
api_key: String,
|
||
model: String,
|
||
client: reqwest::Client,
|
||
thinking_enabled: bool,
|
||
reasoning_effort: Option<String>,
|
||
max_tokens: u32,
|
||
temperature: f32,
|
||
thinking_state: Option<Arc<ThinkingStateManager>>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct ChatCompletionRequest {
|
||
model: String,
|
||
messages: Vec<Message>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
max_tokens: Option<u32>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
temperature: Option<f32>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
top_p: Option<f32>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
presence_penalty: Option<f32>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
frequency_penalty: Option<f32>,
|
||
stream: bool,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
thinking: Option<ThinkingConfig>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
reasoning_effort: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct ThinkingConfig {
|
||
#[serde(rename = "type")]
|
||
thinking_type: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
struct Message {
|
||
role: String,
|
||
content: String,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
reasoning_content: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct ChatCompletionResponse {
|
||
choices: Vec<Choice>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct Choice {
|
||
message: Message,
|
||
#[serde(default)]
|
||
reasoning_content: Option<String>,
|
||
}
|
||
|
||
// --- Streaming response structures ---
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct StreamChunk {
|
||
choices: Vec<StreamChoice>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct StreamChoice {
|
||
delta: StreamDelta,
|
||
#[serde(default)]
|
||
finish_reason: Option<String>,
|
||
index: Option<u32>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize, Default)]
|
||
struct StreamDelta {
|
||
#[serde(default)]
|
||
content: Option<String>,
|
||
#[serde(default)]
|
||
reasoning_content: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct ErrorResponse {
|
||
error: ApiError,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct ApiError {
|
||
message: String,
|
||
#[serde(rename = "type")]
|
||
error_type: String,
|
||
}
|
||
|
||
impl DeepSeekClient {
|
||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||
let client = create_http_client(Duration::from_secs(300))?;
|
||
|
||
Ok(Self {
|
||
base_url: "https://api.deepseek.com".to_string(),
|
||
api_key: api_key.to_string(),
|
||
model: model.to_string(),
|
||
client,
|
||
thinking_enabled: false,
|
||
reasoning_effort: None,
|
||
max_tokens: 500,
|
||
temperature: 0.7,
|
||
thinking_state: None,
|
||
})
|
||
}
|
||
|
||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||
let client = create_http_client(Duration::from_secs(300))?;
|
||
|
||
Ok(Self {
|
||
base_url: base_url.trim_end_matches('/').to_string(),
|
||
api_key: api_key.to_string(),
|
||
model: model.to_string(),
|
||
client,
|
||
thinking_enabled: false,
|
||
reasoning_effort: None,
|
||
max_tokens: 500,
|
||
temperature: 0.7,
|
||
thinking_state: None,
|
||
})
|
||
}
|
||
|
||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||
self.client = create_http_client(timeout)?;
|
||
Ok(self)
|
||
}
|
||
|
||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||
self.thinking_enabled = enabled;
|
||
self
|
||
}
|
||
|
||
pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
|
||
self.reasoning_effort = effort;
|
||
self
|
||
}
|
||
|
||
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
||
self.max_tokens = max_tokens;
|
||
self
|
||
}
|
||
|
||
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
||
self.temperature = temperature;
|
||
self
|
||
}
|
||
|
||
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
|
||
self.thinking_state = Some(state);
|
||
self
|
||
}
|
||
|
||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||
let url = format!("{}/models", self.base_url);
|
||
|
||
let response = self
|
||
.client
|
||
.get(&url)
|
||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||
.send()
|
||
.await
|
||
.context("Failed to list DeepSeek models")?;
|
||
|
||
if !response.status().is_success() {
|
||
let status = response.status();
|
||
let text = response.text().await.unwrap_or_default();
|
||
bail!("DeepSeek API error: {} - {}", status, text);
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct ModelsResponse {
|
||
data: Vec<ModelId>,
|
||
}
|
||
|
||
#[derive(Deserialize)]
|
||
struct ModelId {
|
||
id: String,
|
||
}
|
||
|
||
let result: ModelsResponse = response
|
||
.json()
|
||
.await
|
||
.context("Failed to parse DeepSeek response")?;
|
||
|
||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||
}
|
||
|
||
pub async fn validate_key(&self) -> Result<bool> {
|
||
match self.list_models().await {
|
||
Ok(_) => Ok(true),
|
||
Err(e) => {
|
||
let err_str = e.to_string();
|
||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||
Ok(false)
|
||
} else {
|
||
Err(e)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LlmProvider for DeepSeekClient {
|
||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||
let messages = vec![Message {
|
||
role: "user".to_string(),
|
||
content: prompt.to_string(),
|
||
reasoning_content: None,
|
||
}];
|
||
|
||
self.chat_completion_with_retry(messages).await
|
||
}
|
||
|
||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||
let mut messages = vec![];
|
||
|
||
if !system.is_empty() {
|
||
messages.push(Message {
|
||
role: "system".to_string(),
|
||
content: system.to_string(),
|
||
reasoning_content: None,
|
||
});
|
||
}
|
||
|
||
messages.push(Message {
|
||
role: "user".to_string(),
|
||
content: user.to_string(),
|
||
reasoning_content: None,
|
||
});
|
||
|
||
self.chat_completion_with_retry(messages).await
|
||
}
|
||
|
||
async fn is_available(&self) -> bool {
|
||
self.validate_key().await.unwrap_or(false)
|
||
}
|
||
|
||
fn name(&self) -> &str {
|
||
"deepseek"
|
||
}
|
||
}
|
||
|
||
impl DeepSeekClient {
|
||
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
|
||
let mut last_error = None;
|
||
|
||
for attempt in 1..=3 {
|
||
match self.chat_completion(messages.clone()).await {
|
||
Ok(result) => return Ok(result),
|
||
Err(e) => {
|
||
let err_msg = e.to_string();
|
||
// 网络临时错误才重试
|
||
let is_retryable = err_msg.contains("timeout")
|
||
|| err_msg.contains("connection")
|
||
|| err_msg.contains("temporary")
|
||
|| err_msg.contains("5")
|
||
&& (err_msg.contains("500")
|
||
|| err_msg.contains("502")
|
||
|| err_msg.contains("503")
|
||
|| err_msg.contains("504"));
|
||
|
||
if !is_retryable || attempt == 3 {
|
||
last_error = Some(e);
|
||
break;
|
||
}
|
||
|
||
// 指数退避
|
||
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
|
||
}
|
||
}
|
||
}
|
||
|
||
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
|
||
}
|
||
|
||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||
let url = format!("{}/chat/completions", self.base_url);
|
||
|
||
let thinking = if self.thinking_enabled {
|
||
Some(ThinkingConfig {
|
||
thinking_type: "enabled".to_string(),
|
||
})
|
||
} else {
|
||
None
|
||
};
|
||
|
||
// 思考模式下,temperature/top_p 等参数不应传递
|
||
// 非思考模式下可以正常传递
|
||
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
|
||
if self.thinking_enabled {
|
||
(None, Some(self.max_tokens), None, None, None)
|
||
} else {
|
||
(
|
||
Some(self.temperature),
|
||
Some(self.max_tokens),
|
||
None,
|
||
None,
|
||
None,
|
||
)
|
||
};
|
||
|
||
let reasoning_effort = if self.thinking_enabled {
|
||
self.reasoning_effort.clone()
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let request = ChatCompletionRequest {
|
||
model: self.model.clone(),
|
||
messages: messages.clone(),
|
||
max_tokens,
|
||
temperature,
|
||
top_p,
|
||
presence_penalty,
|
||
frequency_penalty,
|
||
stream: self.thinking_enabled,
|
||
thinking,
|
||
reasoning_effort,
|
||
};
|
||
|
||
if self.thinking_enabled {
|
||
self.streaming_chat_completion(&url, &request).await
|
||
} else {
|
||
self.non_streaming_chat_completion(&url, &request).await
|
||
}
|
||
}
|
||
|
||
/// 非流式请求(非思考模式)
|
||
async fn non_streaming_chat_completion(
|
||
&self,
|
||
url: &str,
|
||
request: &ChatCompletionRequest,
|
||
) -> Result<String> {
|
||
let response = self
|
||
.client
|
||
.post(url)
|
||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||
.header("Content-Type", "application/json")
|
||
.json(request)
|
||
.send()
|
||
.await
|
||
.context("Failed to send request to DeepSeek")?;
|
||
|
||
let status = response.status();
|
||
|
||
if !status.is_success() {
|
||
let text = response.text().await.unwrap_or_default();
|
||
|
||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||
bail!(
|
||
"DeepSeek API error: {} ({})",
|
||
error.error.message,
|
||
error.error.error_type
|
||
);
|
||
}
|
||
|
||
bail!("DeepSeek API error: {} - {}", status, text);
|
||
}
|
||
|
||
let result: ChatCompletionResponse = response
|
||
.json()
|
||
.await
|
||
.context("Failed to parse DeepSeek response")?;
|
||
|
||
result
|
||
.choices
|
||
.into_iter()
|
||
.next()
|
||
.map(|c| c.message.content.trim().to_string())
|
||
.filter(|s| !s.is_empty())
|
||
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
|
||
}
|
||
|
||
/// 流式请求(思考模式),处理 reasoning_content 和 content
|
||
async fn streaming_chat_completion(
|
||
&self,
|
||
url: &str,
|
||
request: &ChatCompletionRequest,
|
||
) -> Result<String> {
|
||
let response = self
|
||
.client
|
||
.post(url)
|
||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||
.header("Content-Type", "application/json")
|
||
.header("Accept", "text/event-stream")
|
||
.json(request)
|
||
.send()
|
||
.await
|
||
.context("Failed to send streaming request to DeepSeek")?;
|
||
|
||
let status = response.status();
|
||
|
||
if !status.is_success() {
|
||
let text = response.text().await.unwrap_or_default();
|
||
|
||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||
bail!(
|
||
"DeepSeek API error: {} ({})",
|
||
error.error.message,
|
||
error.error.error_type
|
||
);
|
||
}
|
||
|
||
bail!("DeepSeek API error: {} - {}", status, text);
|
||
}
|
||
|
||
let mut content_buffer = String::new();
|
||
let mut has_reasoning = false;
|
||
let mut has_content = false;
|
||
let mut stream_ended = false;
|
||
|
||
let thinking_state = self.thinking_state.as_ref();
|
||
|
||
let mut byte_stream = response.bytes_stream();
|
||
let mut line_buffer = String::new();
|
||
|
||
use futures_util::StreamExt;
|
||
|
||
while let Some(chunk) = byte_stream.next().await {
|
||
let chunk = chunk.context("Failed to read streaming response chunk")?;
|
||
let chunk_str =
|
||
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
|
||
|
||
line_buffer.push_str(&chunk_str);
|
||
|
||
// 处理完整行
|
||
while let Some(line_end) = line_buffer.find('\n') {
|
||
let line = line_buffer[..line_end].trim().to_string();
|
||
line_buffer = line_buffer[line_end + 1..].to_string();
|
||
|
||
if line.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
// SSE 格式:data: {...} 或 data: [DONE]
|
||
if line == "data: [DONE]" {
|
||
stream_ended = true;
|
||
break;
|
||
}
|
||
|
||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||
match serde_json::from_str::<StreamChunk>(json_str) {
|
||
Ok(chunk) => {
|
||
for choice in &chunk.choices {
|
||
// 处理 reasoning_content
|
||
if let Some(ref reasoning) = choice.delta.reasoning_content
|
||
&& !reasoning.is_empty()
|
||
{
|
||
if !has_reasoning {
|
||
has_reasoning = true;
|
||
if let Some(state) = thinking_state {
|
||
state.start_thinking();
|
||
}
|
||
}
|
||
// reasoning_content 不对外输出,仅用于内部状态判断
|
||
continue;
|
||
}
|
||
|
||
// 处理 content
|
||
if let Some(ref content) = choice.delta.content
|
||
&& !content.is_empty()
|
||
{
|
||
// reasoning 结束,content 开始出现时移除 thinking 标识
|
||
if has_reasoning
|
||
&& !has_content
|
||
&& let Some(state) = thinking_state
|
||
{
|
||
state.end_thinking();
|
||
}
|
||
has_content = true;
|
||
content_buffer.push_str(content);
|
||
}
|
||
|
||
// 检查 finish_reason
|
||
if let Some(ref reason) = choice.finish_reason
|
||
&& reason == "stop"
|
||
{
|
||
stream_ended = true;
|
||
}
|
||
}
|
||
}
|
||
Err(_) => {
|
||
// 忽略无法解析的行(可能是心跳或注释)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if stream_ended {
|
||
break;
|
||
}
|
||
}
|
||
|
||
// 确保思考状态已结束
|
||
if let Some(state) = thinking_state {
|
||
state.end_thinking();
|
||
}
|
||
|
||
let result = content_buffer.trim().to_string();
|
||
|
||
if result.is_empty() {
|
||
if has_reasoning && !has_content {
|
||
bail!(
|
||
"DeepSeek returned reasoning content but no final answer. \
|
||
The model may have entered an incomplete thinking state. \
|
||
Please try again or disable thinking mode."
|
||
);
|
||
}
|
||
bail!(
|
||
"No response from DeepSeek. \
|
||
If thinking mode is enabled, try disabling it or ensure the model supports it."
|
||
);
|
||
}
|
||
|
||
Ok(result)
|
||
}
|
||
}
|
||
|
||
/// Model identifiers accepted for DeepSeek.
///
/// `deepseek-chat` / `deepseek-reasoner` will be retired on 2026-07-24; the V4
/// series is recommended.
pub const DEEPSEEK_MODELS: &[&str] = &[
    "deepseek-v4-flash",
    "deepseek-v4-pro",
    // Legacy model ids kept for compatibility (to be retired on 2026-07-24).
    "deepseek-chat",
    "deepseek-reasoner",
];

/// Returns `true` when `model` exactly matches an entry of [`DEEPSEEK_MODELS`].
pub fn is_valid_model(model: &str) -> bool {
    DEEPSEEK_MODELS.iter().any(|known| *known == model)
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Listed model ids are accepted; anything else is rejected.
    #[test]
    fn test_model_validation_v4() {
        let accepted = [
            "deepseek-v4-flash",
            "deepseek-v4-pro",
            "deepseek-chat",
            "deepseek-reasoner",
        ];
        for model in accepted {
            assert!(is_valid_model(model));
        }

        let rejected = ["invalid-model", "deepseek-v3"];
        for model in rejected {
            assert!(!is_valid_model(model));
        }
    }

    /// A freshly constructed client carries the documented defaults.
    #[test]
    fn test_client_builder_defaults() {
        let fresh = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();

        assert!(!fresh.thinking_enabled);
        assert_eq!(fresh.max_tokens, 500);
        assert_eq!(fresh.temperature, 0.7);
        assert_eq!(fresh.reasoning_effort, None);
        assert!(fresh.thinking_state.is_none());
    }

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_client_builder_with_thinking() {
        let base = DeepSeekClient::new("test-key", "deepseek-v4-flash").unwrap();
        let tuned = base
            .with_thinking(true)
            .with_reasoning_effort(Some("high".to_string()))
            .with_max_tokens(1000)
            .with_temperature(0.5);

        assert!(tuned.thinking_enabled);
        assert_eq!(tuned.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(tuned.max_tokens, 1000);
        assert_eq!(tuned.temperature, 0.5);
    }

    /// `thinking_type` is emitted under the JSON key `type`.
    #[test]
    fn test_thinking_config_serialization() {
        let encoded = serde_json::to_string(&ThinkingConfig {
            thinking_type: "enabled".to_string(),
        })
        .unwrap();

        assert_eq!(encoded, r#"{"type":"enabled"}"#);
    }

    /// A message without reasoning text omits the `reasoning_content` key.
    #[test]
    fn test_message_serialization_without_reasoning() {
        let plain = Message {
            role: "user".to_string(),
            content: "Hello".to_string(),
            reasoning_content: None,
        };

        let encoded = serde_json::to_string(&plain).unwrap();
        assert!(!encoded.contains("reasoning_content"));
    }

    /// A delta carrying only answer text parses with no reasoning.
    #[test]
    fn test_stream_delta_parsing() {
        let parsed: StreamDelta =
            serde_json::from_str(r#"{"content":"Hello","reasoning_content":null}"#).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hello"));
        assert_eq!(parsed.reasoning_content, None);
    }

    /// A delta carrying only reasoning text parses with no answer text.
    #[test]
    fn test_stream_delta_reasoning_only() {
        let parsed: StreamDelta =
            serde_json::from_str(r#"{"content":null,"reasoning_content":"Let me think..."}"#)
                .unwrap();

        assert_eq!(parsed.content, None);
        assert_eq!(parsed.reasoning_content.as_deref(), Some("Let me think..."));
    }
}
|