use super::{create_http_client, LlmProvider}; use anyhow::{bail, Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::time::Duration; /// OpenRouter API client pub struct OpenRouterClient { base_url: String, api_key: String, model: String, client: reqwest::Client, } #[derive(Debug, Serialize)] struct ChatCompletionRequest { model: String, messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, stream: bool, } #[derive(Debug, Serialize, Deserialize)] struct Message { role: String, content: String, } #[derive(Debug, Deserialize)] struct ChatCompletionResponse { choices: Vec, } #[derive(Debug, Deserialize)] struct Choice { message: Message, } #[derive(Debug, Deserialize)] struct ErrorResponse { error: ApiError, } #[derive(Debug, Deserialize)] struct ApiError { message: String, #[serde(rename = "type")] error_type: String, } impl OpenRouterClient { /// Create new OpenRouter client pub fn new(api_key: &str, model: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; Ok(Self { base_url: "https://openrouter.ai/api/v1".to_string(), api_key: api_key.to_string(), model: model.to_string(), client, }) } /// Create with custom base URL pub fn with_base_url(api_key: &str, model: &str, base_url: &str) -> Result { let client = create_http_client(Duration::from_secs(60))?; Ok(Self { base_url: base_url.trim_end_matches('/').to_string(), api_key: api_key.to_string(), model: model.to_string(), client, }) } /// Set timeout pub fn with_timeout(mut self, timeout: Duration) -> Result { self.client = create_http_client(timeout)?; Ok(self) } /// List available models pub async fn list_models(&self) -> Result> { let url = format!("{}/models", self.base_url); let response = self.client .get(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("HTTP-Referer", "https://quicommit.dev") .header("X-Title", "QuiCommit") .send() .await .context("Failed to list OpenRouter models")?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); bail!("OpenRouter API error: {} - {}", status, text); } #[derive(Deserialize)] struct ModelsResponse { data: Vec, } #[derive(Deserialize)] struct Model { id: String, } let result: ModelsResponse = response .json() .await .context("Failed to parse OpenRouter response")?; Ok(result.data.into_iter().map(|m| m.id).collect()) } /// Validate API key pub async fn validate_key(&self) -> Result { match self.list_models().await { Ok(_) => Ok(true), Err(e) => { let err_str = e.to_string(); if err_str.contains("401") || err_str.contains("Unauthorized") { Ok(false) } else { Err(e) } } } } } #[async_trait] impl LlmProvider for OpenRouterClient { async fn generate(&self, prompt: &str) -> Result { let messages = vec![ Message { role: "user".to_string(), content: prompt.to_string(), }, ]; self.chat_completion(messages).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let mut messages = vec![]; if !system.is_empty() { messages.push(Message { role: "system".to_string(), content: system.to_string(), }); } messages.push(Message { role: "user".to_string(), content: user.to_string(), }); self.chat_completion(messages).await } async fn is_available(&self) -> bool { self.validate_key().await.unwrap_or(false) } fn name(&self) -> &str { "openrouter" } } impl OpenRouterClient { async fn chat_completion(&self, messages: Vec) -> Result { let url = format!("{}/chat/completions", self.base_url); let request = ChatCompletionRequest { model: self.model.clone(), messages, max_tokens: Some(500), temperature: Some(0.7), stream: false, }; let response = self.client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .header("HTTP-Referer", "https://quicommit.dev") .header("X-Title", "QuiCommit") .json(&request) .send() .await .context("Failed to send request to OpenRouter")?; let status = response.status(); if !status.is_success() { let text = response.text().await.unwrap_or_default(); // Try to parse error if let Ok(error) = serde_json::from_str::(&text) { bail!("OpenRouter API error: {} ({})", error.error.message, error.error.error_type); } bail!("OpenRouter API error: {} - {}", status, text); } let result: ChatCompletionResponse = response .json() .await .context("Failed to parse OpenRouter response")?; result.choices .into_iter() .next() .map(|c| c.message.content.trim().to_string()) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } } /// Popular OpenRouter models pub const OPENROUTER_MODELS: &[&str] = &[ "openai/gpt-3.5-turbo", "openai/gpt-4", "openai/gpt-4-turbo", "anthropic/claude-3-opus", "anthropic/claude-3-sonnet", "anthropic/claude-3-haiku", "google/gemini-pro", "meta-llama/llama-2-70b-chat", "mistralai/mixtral-8x7b-instruct", "01-ai/yi-34b-chat", ]; /// Check if a model name is valid pub fn is_valid_model(model: &str) -> bool { // Since OpenRouter supports many models, we'll allow any model name // but provide some popular ones as suggestions true } #[cfg(test)] mod tests { use super::*; #[test] fn test_model_validation() { assert!(is_valid_model("openai/gpt-4")); assert!(is_valid_model("custom/model")); } }