Files
QuiCommit/src/llm/kimi.rs

587 lines
17 KiB
Rust

use super::thinking::ThinkingStateManager;
use super::{LlmProvider, create_http_client};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// Kimi API client (Moonshot AI).
///
/// Built with `KimiClient::new` or `KimiClient::with_base_url` and then adjusted
/// through the `with_*` builder methods.
pub struct KimiClient {
/// API root; endpoint paths such as `/models` and `/chat/completions` are appended to it.
base_url: String,
/// Key sent as `Authorization: Bearer <api_key>` on every request.
api_key: String,
/// Model ID placed in the `model` field of chat completion requests.
model: String,
/// HTTP client obtained from `create_http_client`.
client: reqwest::Client,
/// When true, requests are streamed, carry a `thinking` config and omit `temperature`.
thinking_enabled: bool,
/// Value sent as `max_tokens` in chat completion requests.
max_tokens: u32,
/// Sampling temperature; only sent when thinking mode is disabled.
temperature: f32,
/// Optional shared manager notified via `start_thinking` / `end_thinking` while streaming.
thinking_state: Option<Arc<ThinkingStateManager>>,
}
/// JSON body posted to `{base_url}/chat/completions`.
///
/// Optional fields are left out of the serialized JSON when they are `None`.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
/// Model ID to run the completion with.
model: String,
/// Conversation turns, in order.
messages: Vec<Message>,
/// Token limit for the reply; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
/// Sampling temperature; omitted when `None` (the client leaves it out in thinking mode).
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
/// Whether the response should be streamed; the client sets this to its thinking flag.
stream: bool,
/// Thinking-mode configuration; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
}
/// Thinking-mode switch sent in the request's `thinking` field.
///
/// Serializes as `{"type": "<thinking_type>"}`; this file only ever sends `"enabled"`.
#[derive(Debug, Serialize)]
struct ThinkingConfig {
/// Serialized under the JSON key `type`.
#[serde(rename = "type")]
thinking_type: String,
}
/// A single chat message, used both in requests and in non-streaming responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
/// Speaker role; this file sends `"system"` and `"user"`.
role: String,
/// Message text.
content: String,
/// Reasoning text attached to the message; left out of serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
}
/// Body of a non-streaming chat completion response.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
/// Returned choices; the client only reads the first one.
choices: Vec<Choice>,
}
/// One choice of a non-streaming chat completion response.
#[derive(Debug, Deserialize)]
struct Choice {
/// The generated message.
message: Message,
/// Choice-level reasoning text. When `message.content` is empty, the client falls
/// back to this value first and to `message.reasoning_content` second.
#[serde(default)]
reasoning_content: Option<String>,
}
// --- Streaming response structures ---
/// JSON payload of one `data: ` line in a streamed chat completion.
#[derive(Debug, Deserialize)]
struct StreamChunk {
/// Incremental updates, one entry per choice.
choices: Vec<StreamChoice>,
}
/// Incremental update for one choice within a streamed chunk.
#[derive(Debug, Deserialize)]
struct StreamChoice {
/// Newly produced text for this choice.
delta: StreamDelta,
/// Present once generation for this choice has finished; the client treats `"stop"`
/// as the end of the stream.
#[serde(default)]
finish_reason: Option<String>,
/// Choice index as reported by the API; not read anywhere in the visible code.
index: Option<u32>,
}
/// Text increment carried by a streamed choice. Both fields default to `None` when absent.
#[derive(Debug, Deserialize, Default)]
struct StreamDelta {
/// Next piece of the final answer; the client appends it to its answer buffer.
#[serde(default)]
content: Option<String>,
/// Next piece of reasoning text; the client only uses its presence to track the
/// thinking phase and does not keep the text.
#[serde(default)]
reasoning_content: Option<String>,
}
/// JSON error body the client tries to parse when a chat completion request returns a
/// non-success status.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
/// Error details.
error: ApiError,
}
/// Error details nested inside an `ErrorResponse`.
#[derive(Debug, Deserialize)]
struct ApiError {
/// Description of the failure; included in the error returned to the caller.
message: String,
/// Error category, read from the JSON key `type`; also included in the returned error.
#[serde(rename = "type")]
error_type: String,
}
impl KimiClient {
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: "https://api.moonshot.cn/v1".to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
max_tokens: 500,
temperature: 1.0,
thinking_state: None,
})
}
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(300))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
thinking_enabled: false,
max_tokens: 500,
temperature: 1.0,
thinking_state: None,
})
}
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
pub fn with_thinking(mut self, enabled: bool) -> Self {
self.thinking_enabled = enabled;
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_thinking_state(mut self, state: Arc<ThinkingStateManager>) -> Self {
self.thinking_state = Some(state);
self
}
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.context("Failed to list Kimi models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("Kimi API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<ModelId>,
}
#[derive(Deserialize)]
struct ModelId {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for KimiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![Message {
role: "user".to_string(),
content: prompt.to_string(),
reasoning_content: None,
}];
self.chat_completion_with_retry(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
reasoning_content: None,
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
reasoning_content: None,
});
self.chat_completion_with_retry(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"kimi"
}
}
impl KimiClient {
async fn chat_completion_with_retry(&self, messages: Vec<Message>) -> Result<String> {
let mut last_error = None;
for attempt in 1..=3 {
match self.chat_completion(messages.clone()).await {
Ok(result) => return Ok(result),
Err(e) => {
let err_msg = e.to_string();
let is_retryable = err_msg.contains("timeout")
|| err_msg.contains("connection")
|| err_msg.contains("temporary")
|| err_msg.contains("5")
&& (err_msg.contains("500")
|| err_msg.contains("502")
|| err_msg.contains("503")
|| err_msg.contains("504"));
if !is_retryable || attempt == 3 {
last_error = Some(e);
break;
}
tokio::time::sleep(Duration::from_millis(500 * 2u64.pow(attempt - 1))).await;
}
}
}
Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed after retries")))
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let thinking = if self.thinking_enabled {
Some(ThinkingConfig {
thinking_type: "enabled".to_string(),
})
} else {
None
};
// 对于 kimi-k2.6 等支持思考模式的模型,使用默认 temperature 即可
// 思考模式下不显式指定 temperature
let temperature = if self.thinking_enabled {
None
} else {
Some(self.temperature)
};
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: messages.clone(),
max_tokens: Some(self.max_tokens),
temperature,
stream: self.thinking_enabled,
thinking,
};
if self.thinking_enabled {
self.streaming_chat_completion(&url, &request).await
} else {
self.non_streaming_chat_completion(&url, &request).await
}
}
/// 非流式请求(非思考模式)
async fn non_streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(request)
.send()
.await
.context("Failed to send request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Kimi response")?;
result
.choices
.into_iter()
.next()
.map(|c| {
let content = c.message.content.trim().to_string();
if content.is_empty() {
c.reasoning_content
.or(c.message.reasoning_content)
.map(|r| r.trim().to_string())
.unwrap_or_default()
} else {
content
}
})
.filter(|s| !s.is_empty())
.ok_or_else(|| anyhow::anyhow!("No response from Kimi"))
}
/// 流式请求(思考模式),处理 reasoning_content 和 content
async fn streaming_chat_completion(
&self,
url: &str,
request: &ChatCompletionRequest,
) -> Result<String> {
let response = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.json(request)
.send()
.await
.context("Failed to send streaming request to Kimi")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!(
"Kimi API error: {} ({})",
error.error.message,
error.error.error_type
);
}
bail!("Kimi API error: {} - {}", status, text);
}
let mut content_buffer = String::new();
let mut has_reasoning = false;
let mut has_content = false;
let mut stream_ended = false;
let thinking_state = self.thinking_state.as_ref();
let mut byte_stream = response.bytes_stream();
let mut line_buffer = String::new();
use futures_util::StreamExt;
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.context("Failed to read streaming response chunk")?;
let chunk_str =
String::from_utf8(chunk.to_vec()).context("Invalid UTF-8 in stream chunk")?;
line_buffer.push_str(&chunk_str);
while let Some(line_end) = line_buffer.find('\n') {
let line = line_buffer[..line_end].trim().to_string();
line_buffer = line_buffer[line_end + 1..].to_string();
if line.is_empty() {
continue;
}
if line == "data: [DONE]" {
stream_ended = true;
break;
}
if let Some(json_str) = line.strip_prefix("data: ") {
match serde_json::from_str::<StreamChunk>(json_str) {
Ok(chunk) => {
for choice in &chunk.choices {
if let Some(ref reasoning) = choice.delta.reasoning_content
&& !reasoning.is_empty()
{
if !has_reasoning {
has_reasoning = true;
if let Some(state) = thinking_state {
state.start_thinking();
}
}
continue;
}
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
if has_reasoning
&& !has_content
&& let Some(state) = thinking_state
{
state.end_thinking();
}
has_content = true;
content_buffer.push_str(content);
}
if let Some(ref reason) = choice.finish_reason
&& reason == "stop"
{
stream_ended = true;
}
}
}
Err(_) => {
// 忽略无法解析的行
}
}
}
}
if stream_ended {
break;
}
}
// 确保思考状态已结束
if let Some(state) = thinking_state {
state.end_thinking();
}
let result = content_buffer.trim().to_string();
if result.is_empty() {
if has_reasoning && !has_content {
bail!(
"Kimi returned reasoning content but no final answer. \
The model may have entered an incomplete thinking state. \
Please try again or disable thinking mode."
);
}
bail!(
"No response from Kimi. \
If thinking mode is enabled, try disabling it or ensure the model supports it."
);
}
Ok(result)
}
}
/// List of available Kimi model IDs; `is_valid_model` checks against this list.
pub const KIMI_MODELS: &[&str] = &[
// K2 series (recommended)
"kimi-k2.6",
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-thinking-turbo",
"kimi-k2-instruct",
"kimi-k2-instruct-0905",
// Legacy model IDs kept for compatibility
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
];
/// Reports whether `model` exactly matches one of the IDs in `KIMI_MODELS`.
pub fn is_valid_model(model: &str) -> bool {
    KIMI_MODELS.iter().any(|&known| known == model)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Client shared by the builder tests.
    fn sample_client() -> KimiClient {
        KimiClient::new("test-key", "kimi-k2.6").unwrap()
    }

    /// Listed model IDs are accepted; unknown ones are rejected.
    #[test]
    fn test_model_validation_k2() {
        let accepted = [
            "kimi-k2.6",
            "kimi-k2.5",
            "kimi-k2-thinking",
            "kimi-k2-thinking-turbo",
            "moonshot-v1-8k",
            "moonshot-v1-32k",
            "moonshot-v1-128k",
        ];
        for id in accepted {
            assert!(is_valid_model(id));
        }
        for id in ["invalid-model", "kimi-k1.5"] {
            assert!(!is_valid_model(id));
        }
    }

    /// A freshly built client carries the documented defaults.
    #[test]
    fn test_client_builder_defaults() {
        let fresh = sample_client();
        assert!(!fresh.thinking_enabled);
        assert_eq!(fresh.max_tokens, 500);
        assert_eq!(fresh.temperature, 1.0);
        assert!(fresh.thinking_state.is_none());
    }

    /// Builder methods override the defaults.
    #[test]
    fn test_client_builder_with_thinking() {
        let tuned = sample_client()
            .with_thinking(true)
            .with_max_tokens(1000)
            .with_temperature(0.5);
        assert!(tuned.thinking_enabled);
        assert_eq!(tuned.max_tokens, 1000);
        assert_eq!(tuned.temperature, 0.5);
    }

    /// The thinking config serializes its field under the key `type`.
    #[test]
    fn test_thinking_config_serialization() {
        let encoded = serde_json::to_string(&ThinkingConfig {
            thinking_type: "enabled".to_string(),
        })
        .unwrap();
        assert_eq!(encoded, r#"{"type":"enabled"}"#);
    }

    /// The provider name is fixed and thinking is off by default.
    #[test]
    fn test_client_new_defaults() {
        let fresh = sample_client();
        assert_eq!(fresh.name(), "kimi");
        assert!(!fresh.thinking_enabled);
    }

    /// A message without reasoning text does not emit the `reasoning_content` key.
    #[test]
    fn test_message_serialization() {
        let greeting = Message {
            role: "user".to_string(),
            content: "Hello".to_string(),
            reasoning_content: None,
        };
        let encoded = serde_json::to_string(&greeting).unwrap();
        assert!(!encoded.contains("reasoning_content"));
    }
}