use super::{LlmProvider, create_http_client}; use anyhow::{Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::time::Duration; /// Ollama API client pub struct OllamaClient { base_url: String, model: String, client: reqwest::Client, max_tokens: u32, temperature: f32, top_p: Option, } #[derive(Debug, Serialize)] struct GenerateRequest { model: String, prompt: String, system: Option, stream: bool, options: GenerationOptions, } #[derive(Debug, Serialize, Default)] struct GenerationOptions { #[serde(skip_serializing_if = "Option::is_none")] temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] num_predict: Option, } #[derive(Debug, Deserialize)] struct GenerateResponse { response: String, done: bool, } #[derive(Debug, Deserialize)] struct ListModelsResponse { models: Vec, } #[derive(Debug, Deserialize)] struct ModelInfo { name: String, } impl OllamaClient { /// Create new Ollama client pub fn new(base_url: &str, model: &str) -> Self { let client = create_http_client(Duration::from_secs(120)).expect("Failed to create HTTP client"); Self { base_url: base_url.trim_end_matches('/').to_string(), model: model.to_string(), client, max_tokens: 500, temperature: 0.7, top_p: None, } } /// Set timeout pub fn with_timeout(mut self, timeout: Duration) -> Self { self.client = create_http_client(timeout).expect("Failed to create HTTP client"); self } pub fn with_max_tokens(mut self, max_tokens: u32) -> Self { self.max_tokens = max_tokens; self } pub fn with_temperature(mut self, temperature: f32) -> Self { self.temperature = temperature; self } pub fn with_top_p(mut self, top_p: f32) -> Self { self.top_p = Some(top_p); self } /// List available models pub async fn list_models(&self) -> Result> { let url = format!("{}/api/tags", self.base_url); let response = self .client .get(&url) .send() .await .context("Failed to list Ollama models")?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); anyhow::bail!("Ollama API error: {} - {}", status, text); } let result: ListModelsResponse = response .json() .await .context("Failed to parse Ollama response")?; Ok(result.models.into_iter().map(|m| m.name).collect()) } /// Pull a model pub async fn pull_model(&self, model: &str) -> Result<()> { let url = format!("{}/api/pull", self.base_url); let request = serde_json::json!({ "name": model, "stream": false, }); let response = self .client .post(&url) .json(&request) .send() .await .context("Failed to pull Ollama model")?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); anyhow::bail!("Ollama pull error: {} - {}", status, text); } Ok(()) } /// Check if model exists pub async fn model_exists(&self, model: &str) -> bool { match self.list_models().await { Ok(models) => models.contains(&model.to_string()), Err(_) => false, } } } #[async_trait] impl LlmProvider for OllamaClient { async fn generate(&self, prompt: &str) -> Result { self.generate_with_system("", prompt).await } async fn generate_with_system(&self, system: &str, user: &str) -> Result { let url = format!("{}/api/generate", self.base_url); let system = if system.is_empty() { None } else { Some(system.to_string()) }; let request = GenerateRequest { model: self.model.clone(), prompt: user.to_string(), system, stream: false, options: GenerationOptions { temperature: Some(self.temperature), num_predict: Some(self.max_tokens), }, }; let response = self .client .post(&url) .json(&request) .send() .await .context("Failed to send request to Ollama")?; if !response.status().is_success() { let status = response.status(); let text = response.text().await.unwrap_or_default(); anyhow::bail!("Ollama API error: {} - {}", status, text); } let result: GenerateResponse = response .json() .await .context("Failed to parse Ollama response")?; Ok(result.response.trim().to_string()) } async fn is_available(&self) -> bool { let url = format!("{}/api/tags", self.base_url); match self.client.get(&url).send().await { Ok(response) => response.status().is_success(), Err(_) => false, } } fn name(&self) -> &str { "ollama" } } #[cfg(test)] mod tests { use super::*; // These tests require a running Ollama server #[tokio::test] #[ignore] async fn test_ollama_connection() { let client = OllamaClient::new("http://localhost:11434", "llama2"); assert!(client.is_available().await); } #[tokio::test] #[ignore] async fn test_ollama_generate() { let client = OllamaClient::new("http://localhost:11434", "llama2"); let response = client.generate("Hello, how are you?").await; assert!(response.is_ok()); println!("Response: {}", response.unwrap()); } }