Files
QuiCommit/src/llm/deepseek.rs

623 lines
19 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// DeepSeek API client.
///
/// Construct with [`DeepSeekClient::new`] or [`DeepSeekClient::with_base_url`]
/// and adjust it through the `with_*` builder methods.
pub struct DeepSeekClient {
    // --- Connection ---
    /// API root without a trailing slash; endpoint paths are appended to it.
    base_url: String,
    /// Secret sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// HTTP client used for every request made by this instance.
    client: reqwest::Client,

    // --- Request parameters ---
    /// Model ID placed in each request body.
    model: String,
    /// Token budget sent as `max_tokens`.
    max_tokens: u32,
    /// Sampling temperature; only sent when thinking mode is off.
    temperature: f32,

    // --- Thinking mode ---
    /// When `true`, requests enable thinking mode and are read as a stream.
    thinking_enabled: bool,
    /// Optional `reasoning_effort` value, forwarded only in thinking mode.
    reasoning_effort: Option<String>,
    /// Optional shared indicator toggled while the model is reasoning.
    thinking_state: Option<Arc<ThinkingStateManager>>,
}
/// JSON body for `POST {base_url}/chat/completions`.
///
/// Options that are `None` are omitted from the serialized JSON. Note that the
/// field declaration order is also the key order of the emitted JSON.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    /// Model ID to run.
    model: String,
    /// Conversation turns, in order.
    messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Requests a streamed response when `true`.
    stream: bool,
    /// Thinking-mode switch; see `ThinkingConfig`.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking: Option<ThinkingConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
}
/// Value of the request's `thinking` field; serializes as `{"type": "<thinking_type>"}`.
#[derive(Debug, Serialize)]
struct ThinkingConfig {
    #[serde(rename = "type")]
    thinking_type: String,
}
/// One chat message.
///
/// Used both in request bodies and to deserialize `choices[].message` in
/// non-streaming responses. `Clone` lets the retry loop resend the same turns.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    /// Speaker role, e.g. `"system"` or `"user"`.
    role: String,
    content: String,
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
/// Non-streaming `/chat/completions` response body, limited to the fields
/// this client reads.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<Choice>,
}
/// One completion choice of a non-streaming response.
///
/// Only `message` is read. serde ignores unknown keys by default, so any other
/// per-choice fields the API sends are skipped without error. Reasoning text,
/// when present, is carried inside `message` (see `Message::reasoning_content`),
/// not at the choice level.
#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}
// --- Streaming response structures ---

/// Payload of one SSE `data:` line of a streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    /// Defaults to empty so a chunk that omits `choices` (e.g. a usage-only
    /// chunk) still parses instead of being discarded.
    #[serde(default)]
    choices: Vec<StreamChoice>,
}
/// One choice inside a [`StreamChunk`].
///
/// Unknown keys (such as `index`) are ignored by serde.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    /// Incremental output. Defaults to an empty delta so a choice carrying
    /// only `finish_reason` is accepted rather than rejected, which would lose
    /// the stop signal.
    #[serde(default)]
    delta: StreamDelta,
    /// `"stop"` marks the normal end of generation.
    #[serde(default)]
    finish_reason: Option<String>,
}
/// Incremental text carried by a [`StreamChoice`].
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
    /// Fragment of the final answer.
    #[serde(default)]
    content: Option<String>,
    /// Fragment of the model's reasoning (thinking mode).
    #[serde(default)]
    reasoning_content: Option<String>,
}
/// Structured error body of a failed API call: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ApiError,
}
/// Details nested under the `error` key of an [`ErrorResponse`].
#[derive(Debug, Deserialize)]
struct ApiError {
    /// Error category, read from the JSON `type` key.
    #[serde(rename = "type")]
    error_type: String,
    /// Human-readable description of the failure.
    message: String,
}
impl DeepSeekClient {
    /// Official DeepSeek API endpoint used by [`DeepSeekClient::new`].
    const DEFAULT_BASE_URL: &'static str = "https://api.deepseek.com";
    /// HTTP timeout applied by the constructors; see [`DeepSeekClient::with_timeout`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
    /// Default token budget sent as `max_tokens`.
    const DEFAULT_MAX_TOKENS: u32 = 500;
    /// Default sampling temperature.
    const DEFAULT_TEMPERATURE: f32 = 0.7;

    /// Creates a client for the official DeepSeek endpoint.
    ///
    /// Thinking mode starts disabled, with a 500-token budget, temperature 0.7
    /// and a 300-second HTTP timeout.
    ///
    /// # Errors
    /// Returns an error if `create_http_client` fails.
    pub fn new(api_key: &str, model: &str) -> Result<Self> {
        Self::with_base_url(api_key, model, Self::DEFAULT_BASE_URL)
    }

    /// Creates a client for a custom (DeepSeek-compatible) endpoint.
    ///
    /// Trailing slashes are stripped from `base_url` so endpoint paths can be
    /// appended directly. All other defaults match [`DeepSeekClient::new`].
    ///
    /// # Errors
    /// Returns an error if `create_http_client` fails.
    pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
        let client = create_http_client(Self::DEFAULT_TIMEOUT)?;
        Ok(Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.to_string(),
            model: model.to_string(),
            client,
            thinking_enabled: false,
            reasoning_effort: None,
            max_tokens: Self::DEFAULT_MAX_TOKENS,
            temperature: Self::DEFAULT_TEMPERATURE,
            thinking_state: None,
        })
    }

    /// Replaces the HTTP client with one built for `timeout`.
    ///
    /// # Errors
    /// Returns an error if `create_http_client` fails.
    pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
        self.client = create_http_client(timeout)?;
        Ok(self)
    }

    /// Enables or disables thinking mode.
    pub fn with_thinking(mut self, enabled: bool) -> Self {
        self.thinking_enabled = enabled;
        self
    }

    /// Sets the `reasoning_effort` value (only sent when thinking mode is on).
    pub fn with_reasoning_effort(mut self, effort: Option<String>) -> Self {
        self.reasoning_effort = effort;
        self
    }

    /// Sets the token budget sent as `max_tokens`.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the sampling temperature (only sent when thinking mode is off).
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    /// Attaches a shared indicator that is toggled while the model is reasoning.
    pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
        self.thinking_state = Some(state);
        self
    }

    /// Lists the model IDs reported by `GET {base_url}/models`.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, or if the status is not successful.
    /// In that case the message includes the HTTP status and the body. It also
    /// fails if the body is not the expected `{"data": [{"id": ...}]}` shape.
    pub async fn list_models(&self) -> Result<Vec<String>> {
        let url = format!("{}/models", self.base_url);
        let response = self
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
            .context("Failed to list DeepSeek models")?;
        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            bail!("DeepSeek API error: {} - {}", status, text);
        }
        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<ModelId>,
        }
        #[derive(Deserialize)]
        struct ModelId {
            id: String,
        }
        let result: ModelsResponse = response
            .json()
            .await
            .context("Failed to parse DeepSeek response")?;
        Ok(result.data.into_iter().map(|m| m.id).collect())
    }

    /// Checks the API key by listing models.
    ///
    /// Returns `Ok(true)` on success and `Ok(false)` when the failure message
    /// mentions `401` or `Unauthorized`. The status is part of the message that
    /// [`DeepSeekClient::list_models`] produces for non-success responses. Any
    /// other error is propagated.
    pub async fn validate_key(&self) -> Result<bool> {
        match self.list_models().await {
            Ok(_) => Ok(true),
            Err(e) => {
                let err_str = e.to_string();
                if err_str.contains("401") || err_str.contains("Unauthorized") {
                    Ok(false)
                } else {
                    Err(e)
                }
            }
        }
    }
}
#[async_trait]
impl LlmProvider for DeepSeekClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
reasoning_content: None,
}];
self.chat_completion_with_retry(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
reasoning_content: None,
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
reasoning_content: None,
});
self.chat_completion_with_retry(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"deepseek"
}
}
impl DeepSeekClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
// 网络临时错误才重试
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
// 指数退避
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let thinking = if self.thinking_enabled {
Some(ThinkingConfig {
thinking_type: "enabled".to_string(),
})
} else {
None
};
// 思考模式下temperature/top_p 等参数不应传递
// 非思考模式下可以正常传递
let (temperature, max_tokens, top_p, presence_penalty, frequency_penalty) =
if self.thinking_enabled {
(None, Some(self.max_tokens), None, None, None)
} else {
(
Some(self.temperature),
Some(self.max_tokens),
None,
None,
None,
)
};
let reasoning_effort = if self.thinking_enabled {
self.reasoning_effort.clone()
} else {
None
};
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: messages.clone(),
max_tokens,
temperature,
top_p,
presence_penalty,
frequency_penalty,
stream: self.thinking_enabled,
thinking,
reasoning_effort,
};
if self.thinking_enabled {
self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(request)
.send()
.await
.context("Failed to send request to DeepSeek")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("DeepSeek API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse DeepSeek response")?;
result
.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from DeepSeek"))
}
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to DeepSeek")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"DeepSeek API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("DeepSeek API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
// 处理完整行
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
// SSE 格式data: {...} 或 data: [DONE]
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
// 处理 reasoning_content
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
// reasoning_content 不对外输出,仅用于内部状态判断
continue;
}
// 处理 content
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
// reasoning 结束content 开始出现时移除 thinking 标识
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
// 检查 finish_reason
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行(可能是心跳或注释)
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"DeepSeek returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from DeepSeek. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// DeepSeek model IDs accepted by [`is_valid_model`].
///
/// `deepseek-chat` and `deepseek-reasoner` are scheduled to be retired on
/// 2026-07-24; the V4 series is the recommended replacement.
pub const DEEPSEEK_MODELS: &[&str] = &[
    "deepseek-v4-flash",
    "deepseek-v4-pro",
    // Legacy model IDs, kept for compatibility until their retirement on 2026-07-24.
    "deepseek-chat",
    "deepseek-reasoner",
];

/// Returns `true` if `model` exactly matches (case-sensitively) one of
/// [`DEEPSEEK_MODELS`].
pub fn is_valid_model(model: &str) -> bool {
    DEEPSEEK_MODELS.iter().any(|&known| known == model)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every published model ID is accepted; anything else is rejected.
    #[test]
    fn test_model_validation_v4() {
        let accepted = [
            "deepseek-v4-flash",
            "deepseek-v4-pro",
            "deepseek-chat",
            "deepseek-reasoner",
        ];
        for model in accepted {
            assert!(is_valid_model(model), "expected {model} to be valid");
        }
        for model in ["invalid-model", "deepseek-v3"] {
            assert!(!is_valid_model(model), "expected {model} to be invalid");
        }
    }

    /// A freshly constructed client has thinking off and the documented defaults.
    #[test]
    fn test_client_builder_defaults() {
        let client =
            DeepSeekClient::new("test-key", "deepseek-v4-flash").expect("client should build");
        assert!(!client.thinking_enabled);
        assert!(client.reasoning_effort.is_none());
        assert!(client.thinking_state.is_none());
        assert_eq!((client.max_tokens, client.temperature), (500, 0.7));
    }

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_client_builder_with_thinking() {
        let client = DeepSeekClient::new("test-key", "deepseek-v4-flash")
            .expect("client should build")
            .with_thinking(true)
            .with_reasoning_effort(Some("high".to_string()))
            .with_max_tokens(1000)
            .with_temperature(0.5);
        assert!(client.thinking_enabled);
        assert_eq!(client.reasoning_effort.as_deref(), Some("high"));
        assert_eq!((client.max_tokens, client.temperature), (1000, 0.5));
    }

    /// `thinking_type` is serialized under the JSON key `type`.
    #[test]
    fn test_thinking_config_serialization() {
        let config = ThinkingConfig {
            thinking_type: "enabled".into(),
        };
        let json = serde_json::to_string(&config).expect("serializable");
        assert_eq!(json, r#"{"type":"enabled"}"#);
    }

    /// A `None` reasoning field is omitted from the request JSON.
    #[test]
    fn test_message_serialization_without_reasoning() {
        let json = serde_json::to_string(&Message {
            role: "user".into(),
            content: "Hello".into(),
            reasoning_content: None,
        })
        .expect("serializable");
        assert!(!json.contains("reasoning_content"), "unexpected field in {json}");
    }

    /// A content-only delta parses with no reasoning text.
    #[test]
    fn test_stream_delta_parsing() {
        let delta: StreamDelta =
            serde_json::from_str(r#"{"content":"Hello","reasoning_content":null}"#)
                .expect("valid delta");
        assert_eq!(delta.content.as_deref(), Some("Hello"));
        assert!(delta.reasoning_content.is_none());
    }

    /// A reasoning-only delta parses with no answer text.
    #[test]
    fn test_stream_delta_reasoning_only() {
        let delta: StreamDelta =
            serde_json::from_str(r#"{"content":null,"reasoning_content":"Let me think..."}"#)
                .expect("valid delta");
        assert!(delta.content.is_none());
        assert_eq!(delta.reasoning_content.as_deref(), Some("Let me think..."));
    }
}