feat: add 3 new LLM providers and improve the README
This commit is contained in:
229
src/llm/openrouter.rs
Normal file
229
src/llm/openrouter.rs
Normal file
@@ -0,0 +1,229 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// OpenRouter API client.
///
/// Talks to the OpenRouter chat-completions HTTP API using a pre-built
/// `reqwest` client (constructed via the module-level `create_http_client`).
pub struct OpenRouterClient {
    // Endpoint prefix, e.g. "https://openrouter.ai/api/v1"; stored without a
    // trailing slash so paths can be appended with `format!("{}/...")`.
    base_url: String,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Model identifier in OpenRouter's "vendor/model" form (see OPENROUTER_MODELS).
    model: String,
    // HTTP client carrying the configured request timeout.
    client: reqwest::Client,
}
|
||||
|
||||
/// Request body for POST `/chat/completions` (OpenAI-compatible schema).
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    // Omitted from the JSON entirely when None so the server default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Always false in this client: the response is read as one JSON body,
    // not as a server-sent-events stream.
    stream: bool,
}
|
||||
|
||||
/// A single chat turn, shared between requests and responses.
#[derive(Debug, Serialize, Deserialize)]
struct Message {
    // "system" or "user" when sent by this client; on responses it is
    // whatever role the API reports for the reply.
    role: String,
    content: String,
}
|
||||
|
||||
/// Successful response body; only the list of choices is consumed
/// (unknown fields are ignored by serde's default behavior).
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<Choice>,
}
|
||||
|
||||
/// One completion candidate; only `message.content` is used downstream.
#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
}
|
||||
|
||||
/// Wrapper for the error body the API returns on non-2xx statuses.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: ApiError,
}
|
||||
|
||||
/// API error detail embedded in `ErrorResponse`.
#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
    // The JSON field is literally named "type", which is a Rust keyword,
    // hence the rename.
    #[serde(rename = "type")]
    error_type: String,
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
/// Create new OpenRouter client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: "https://openrouter.ai/api/v1".to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with custom base URL
|
||||
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to validate OpenRouter API key")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
Ok(true)
|
||||
} else if response.status().as_u16() == 401 {
|
||||
Ok(false)
|
||||
} else {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("OpenRouter API error: {} - {}", status, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenRouterClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"openrouter"
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenRouterClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("HTTP-Referer", "https://quicommit.dev")
|
||||
.header("X-Title", "QuiCommit")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to OpenRouter")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
bail!("OpenRouter API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenRouter response")?;
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Popular OpenRouter models, offered as suggestions to users.
///
/// OpenRouter accepts many model ids beyond this list; it is advisory only
/// and is not used as an allow-list by `is_valid_model`.
pub const OPENROUTER_MODELS: &[&str] = &[
    "openai/gpt-3.5-turbo",
    "openai/gpt-4",
    "openai/gpt-4-turbo",
    "anthropic/claude-3-opus",
    "anthropic/claude-3-sonnet",
    "anthropic/claude-3-haiku",
    "google/gemini-pro",
    "meta-llama/llama-2-70b-chat",
    "mistralai/mixtral-8x7b-instruct",
    "01-ai/yi-34b-chat",
];
|
||||
|
||||
/// Check whether a model name is plausibly valid.
///
/// OpenRouter's catalogue is open-ended, so no allow-list is enforced
/// (`OPENROUTER_MODELS` is advisory only). Only obviously invalid names —
/// empty or whitespace-only strings — are rejected. The previous
/// implementation ignored `model` entirely and accepted every input,
/// including the empty string, which also triggered an unused-variable
/// warning.
pub fn is_valid_model(model: &str) -> bool {
    !model.trim().is_empty()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Both catalogued and arbitrary custom model ids are accepted —
    /// validation is deliberately permissive because OpenRouter's model
    /// catalogue is open-ended.
    #[test]
    fn test_model_validation() {
        for model in ["openai/gpt-4", "custom/model"] {
            assert!(is_valid_model(model), "{model} should be accepted");
        }
    }
}
|
||||
Reference in New Issue
Block a user