feat(deepseek): 添加 DeepSeek reasoning 模式支持
This commit is contained in:
@@ -10,6 +10,7 @@ pub struct DeepSeekClient {
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
thinking_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -21,6 +22,14 @@ struct ChatCompletionRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
thinking: Option<ThinkingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ThinkingConfig {
|
||||
#[serde(rename = "type")]
|
||||
thinking_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -37,6 +46,8 @@ struct ChatCompletionResponse {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -55,24 +66,26 @@ impl DeepSeekClient {
|
||||
/// Create new DeepSeek client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://api.deepseek.com/v1".to_string(),
|
||||
base_url: "https://api.deepseek.com/".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
thinking_enabled: false,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -82,6 +95,12 @@ impl DeepSeekClient {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Enable or disable thinking mode
|
||||
pub fn with_thinking(mut self, enabled: bool) -> Self {
|
||||
self.thinking_enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
@@ -142,25 +161,25 @@ impl LlmProvider for DeepSeekClient {
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
@@ -176,15 +195,24 @@ impl LlmProvider for DeepSeekClient {
|
||||
impl DeepSeekClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
|
||||
let thinking = if self.thinking_enabled {
|
||||
Some(ThinkingConfig {
|
||||
thinking_type: "enabled".to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
stream: false,
|
||||
thinking,
|
||||
};
|
||||
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
@@ -193,25 +221,24 @@ impl DeepSeekClient {
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to DeepSeek")?;
|
||||
|
||||
|
||||
let status = response.status();
|
||||
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("DeepSeek API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
|
||||
bail!("DeepSeek API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse DeepSeek response")?;
|
||||
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
@@ -223,7 +250,7 @@ impl DeepSeekClient {
|
||||
/// Available DeepSeek models
|
||||
pub const DEEPSEEK_MODELS: &[&str] = &[
|
||||
"deepseek-chat",
|
||||
"deepseek-coder",
|
||||
"deepseek-reasoner",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
@@ -238,6 +265,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
assert!(is_valid_model("deepseek-chat"));
|
||||
assert!(is_valid_model("deepseek-reasoner"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user