257 lines
7.2 KiB
Rust
257 lines
7.2 KiB
Rust
use super::{create_http_client, LlmProvider};
|
|
use anyhow::{bail, Context, Result};
|
|
use async_trait::async_trait;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::time::Duration;
|
|
|
|
/// OpenRouter API client
///
/// Holds the API root, credentials, target model and a reusable HTTP client.
/// Construct with `new` (public OpenRouter endpoint) or `with_base_url`.
pub struct OpenRouterClient {
    /// API root without a trailing slash; request paths are appended with `format!("{}/...")`.
    base_url: String,
    /// Sent as `Authorization: Bearer <api_key>` on every request.
    api_key: String,
    /// Model id placed in every chat-completion request body.
    model: String,
    /// Shared HTTP client; replaced by `with_timeout`.
    client: reqwest::Client,
}
|
|
|
|
/// JSON body for `POST {base_url}/chat/completions`.
///
/// Field declaration order is the key order serde emits.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    /// Model id to run.
    model: String,
    /// Conversation messages, in the order they are sent.
    messages: Vec<Message>,
    /// Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// `chat_completion` sets this to `false` and reads a single JSON response.
    stream: bool,
}
|
|
|
|
/// One chat message; serialized in requests and deserialized from response choices.
#[derive(Debug, Serialize, Deserialize)]
struct Message {
    /// Speaker role; this file sends `"system"` and `"user"`.
    role: String,
    /// Message text. Deserialization requires a JSON string here.
    // NOTE(review): if a provider returns `"content": null` (e.g. tool-call replies),
    // parsing the whole response fails — confirm whether that can happen for our use.
    content: String,
}
|
|
|
|
/// The part of a successful chat-completion response that this client reads.
///
/// Any other JSON keys in the body are ignored (no `deny_unknown_fields`).
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    /// Completion alternatives; `chat_completion` uses only the first.
    choices: Vec<Choice>,
}
|
|
|
|
/// A single completion alternative inside `ChatCompletionResponse`.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The generated assistant message.
    message: Message,
}
|
|
|
|
/// Error envelope of the form `{"error": {...}}`, parsed from non-2xx response bodies.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// The nested error details.
    error: ApiError,
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ApiError {
|
|
message: String,
|
|
#[serde(rename = "type")]
|
|
error_type: String,
|
|
}
|
|
|
|
impl OpenRouterClient {
|
|
/// Create new OpenRouter client
|
|
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
|
let client = create_http_client(Duration::from_secs(60))?;
|
|
|
|
Ok(Self {
|
|
base_url: "https://openrouter.ai/api/v1".to_string(),
|
|
api_key: api_key.to_string(),
|
|
model: model.to_string(),
|
|
client,
|
|
})
|
|
}
|
|
|
|
/// Create with custom base URL
|
|
pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result<Self> {
|
|
let client = create_http_client(Duration::from_secs(60))?;
|
|
|
|
Ok(Self {
|
|
base_url: base_url.trim_end_matches('/').to_string(),
|
|
api_key: api_key.to_string(),
|
|
model: model.to_string(),
|
|
client,
|
|
})
|
|
}
|
|
|
|
/// Set timeout
|
|
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
|
self.client = create_http_client(timeout)?;
|
|
Ok(self)
|
|
}
|
|
|
|
/// List available models
|
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
|
let url = format!("{}/models", self.base_url);
|
|
|
|
let response = self.client
|
|
.get(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("HTTP-Referer", "https://quicommit.dev")
|
|
.header("X-Title", "QuiCommit")
|
|
.send()
|
|
.await
|
|
.context("Failed to list OpenRouter models")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
bail!("OpenRouter API error: {} - {}", status, text);
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct ModelsResponse {
|
|
data: Vec<Model>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct Model {
|
|
id: String,
|
|
}
|
|
|
|
let result: ModelsResponse = response
|
|
.json()
|
|
.await
|
|
.context("Failed to parse OpenRouter response")?;
|
|
|
|
Ok(result.data.into_iter().map(|m| m.id).collect())
|
|
}
|
|
|
|
/// Validate API key
|
|
pub async fn validate_key(&self) -> Result<bool> {
|
|
match self.list_models().await {
|
|
Ok(_) => Ok(true),
|
|
Err(e) => {
|
|
let err_str = e.to_string();
|
|
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
|
Ok(false)
|
|
} else {
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmProvider for OpenRouterClient {
|
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
|
let messages = vec![
|
|
Message {
|
|
role: "user".to_string(),
|
|
content: prompt.to_string(),
|
|
},
|
|
];
|
|
|
|
self.chat_completion(messages).await
|
|
}
|
|
|
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
|
let mut messages = vec![];
|
|
|
|
if !system.is_empty() {
|
|
messages.push(Message {
|
|
role: "system".to_string(),
|
|
content: system.to_string(),
|
|
});
|
|
}
|
|
|
|
messages.push(Message {
|
|
role: "user".to_string(),
|
|
content: user.to_string(),
|
|
});
|
|
|
|
self.chat_completion(messages).await
|
|
}
|
|
|
|
async fn is_available(&self) -> bool {
|
|
self.validate_key().await.unwrap_or(false)
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
"openrouter"
|
|
}
|
|
}
|
|
|
|
impl OpenRouterClient {
|
|
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
|
let url = format!("{}/chat/completions", self.base_url);
|
|
|
|
let request = ChatCompletionRequest {
|
|
model: self.model.clone(),
|
|
messages,
|
|
max_tokens: Some(500),
|
|
temperature: Some(0.7),
|
|
stream: false,
|
|
};
|
|
|
|
let response = self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.header("Content-Type", "application/json")
|
|
.header("HTTP-Referer", "https://quicommit.dev")
|
|
.header("X-Title", "QuiCommit")
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.context("Failed to send request to OpenRouter")?;
|
|
|
|
let status = response.status();
|
|
|
|
if !status.is_success() {
|
|
let text = response.text().await.unwrap_or_default();
|
|
|
|
// Try to parse error
|
|
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
|
bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type);
|
|
}
|
|
|
|
bail!("OpenRouter API error: {} - {}", status, text);
|
|
}
|
|
|
|
let result: ChatCompletionResponse = response
|
|
.json()
|
|
.await
|
|
.context("Failed to parse OpenRouter response")?;
|
|
|
|
result.choices
|
|
.into_iter()
|
|
.next()
|
|
.map(|c| c.message.content.trim().to_string())
|
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
|
}
|
|
}
|
|
|
|
/// A short list of well-known OpenRouter model ids, offered as suggestions.
///
/// Not exhaustive: `is_valid_model` does not restrict names to this list.
pub const OPENROUTER_MODELS: &[&str] = &[
    // OpenAI
    "openai/gpt-3.5-turbo",
    "openai/gpt-4",
    "openai/gpt-4-turbo",
    // Anthropic
    "anthropic/claude-3-opus",
    "anthropic/claude-3-sonnet",
    "anthropic/claude-3-haiku",
    // Google
    "google/gemini-pro",
    // Open-weight models
    "meta-llama/llama-2-70b-chat",
    "mistralai/mixtral-8x7b-instruct",
    "01-ai/yi-34b-chat",
];
|
|
|
|
/// Check whether `model` is usable as an OpenRouter model id.
///
/// OpenRouter serves far more models than `OPENROUTER_MODELS` lists (that list
/// holds only suggestions), so any non-blank name is accepted. Empty or
/// whitespace-only names are rejected because they cannot identify a model.
pub fn is_valid_model(model: &str) -> bool {
    !model.trim().is_empty()
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Both a suggested model id and an arbitrary custom id are accepted.
    #[test]
    fn test_model_validation() {
        let accepted = ["openai/gpt-4", "custom/model"];
        for name in accepted {
            assert!(is_valid_model(name), "expected {:?} to be accepted", name);
        }
    }
}