230 lines
6.0 KiB
Rust
230 lines
6.0 KiB
Rust
use super::{LlmProvider, create_http_client};
|
|
use anyhow::{Context, Result};
|
|
use async_trait::async_trait;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::time::Duration;
|
|
|
|
/// Ollama API client
|
|
pub struct OllamaClient {
|
|
base_url: String,
|
|
model: String,
|
|
client: reqwest::Client,
|
|
max_tokens: u32,
|
|
temperature: f32,
|
|
top_p: Option<f32>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct GenerateRequest {
|
|
model: String,
|
|
prompt: String,
|
|
system: Option<String>,
|
|
stream: bool,
|
|
options: GenerationOptions,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Default)]
|
|
struct GenerationOptions {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
temperature: Option<f32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
num_predict: Option<u32>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GenerateResponse {
|
|
response: String,
|
|
done: bool,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ListModelsResponse {
|
|
models: Vec<ModelInfo>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ModelInfo {
|
|
name: String,
|
|
}
|
|
|
|
impl OllamaClient {
|
|
/// Create new Ollama client
|
|
pub fn new(base_url: &str, model: &str) -> Self {
|
|
let client =
|
|
create_http_client(Duration::from_secs(120)).expect("Failed to create HTTP client");
|
|
|
|
Self {
|
|
base_url: base_url.trim_end_matches('/').to_string(),
|
|
model: model.to_string(),
|
|
client,
|
|
max_tokens: 500,
|
|
temperature: 0.7,
|
|
top_p: None,
|
|
}
|
|
}
|
|
|
|
/// Set timeout
|
|
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
|
self.client = create_http_client(timeout).expect("Failed to create HTTP client");
|
|
self
|
|
}
|
|
|
|
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
|
|
self.max_tokens = max_tokens;
|
|
self
|
|
}
|
|
|
|
pub fn with_temperature(mut self, temperature: f32) -> Self {
|
|
self.temperature = temperature;
|
|
self
|
|
}
|
|
|
|
pub fn with_top_p(mut self, top_p: f32) -> Self {
|
|
self.top_p = Some(top_p);
|
|
self
|
|
}
|
|
|
|
/// List available models
|
|
pub async fn list_models(&self) -> Result<Vec<String>> {
|
|
let url = format!("{}/api/tags", self.base_url);
|
|
|
|
let response = self
|
|
.client
|
|
.get(&url)
|
|
.send()
|
|
.await
|
|
.context("Failed to list Ollama models")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
|
}
|
|
|
|
let result: ListModelsResponse = response
|
|
.json()
|
|
.await
|
|
.context("Failed to parse Ollama response")?;
|
|
|
|
Ok(result.models.into_iter().map(|m| m.name).collect())
|
|
}
|
|
|
|
/// Pull a model
|
|
pub async fn pull_model(&self, model: &str) -> Result<()> {
|
|
let url = format!("{}/api/pull", self.base_url);
|
|
|
|
let request = serde_json::json!({
|
|
"name": model,
|
|
"stream": false,
|
|
});
|
|
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.context("Failed to pull Ollama model")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("Ollama pull error: {} - {}", status, text);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Check if model exists
|
|
pub async fn model_exists(&self, model: &str) -> bool {
|
|
match self.list_models().await {
|
|
Ok(models) => models.contains(&model.to_string()),
|
|
Err(_) => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LlmProvider for OllamaClient {
|
|
async fn generate(&self, prompt: &str) -> Result<String> {
|
|
self.generate_with_system("", prompt).await
|
|
}
|
|
|
|
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
|
let url = format!("{}/api/generate", self.base_url);
|
|
|
|
let system = if system.is_empty() {
|
|
None
|
|
} else {
|
|
Some(system.to_string())
|
|
};
|
|
|
|
let request = GenerateRequest {
|
|
model: self.model.clone(),
|
|
prompt: user.to_string(),
|
|
system,
|
|
stream: false,
|
|
options: GenerationOptions {
|
|
temperature: Some(self.temperature),
|
|
num_predict: Some(self.max_tokens),
|
|
},
|
|
};
|
|
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.context("Failed to send request to Ollama")?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
|
}
|
|
|
|
let result: GenerateResponse = response
|
|
.json()
|
|
.await
|
|
.context("Failed to parse Ollama response")?;
|
|
|
|
Ok(result.response.trim().to_string())
|
|
}
|
|
|
|
async fn is_available(&self) -> bool {
|
|
let url = format!("{}/api/tags", self.base_url);
|
|
|
|
match self.client.get(&url).send().await {
|
|
Ok(response) => response.status().is_success(),
|
|
Err(_) => false,
|
|
}
|
|
}
|
|
|
|
fn name(&self) -> &str {
|
|
"ollama"
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Address of the local Ollama server the tests talk to.
    const SERVER_URL: &str = "http://localhost:11434";
    /// Model name used by the tests.
    const MODEL_NAME: &str = "llama2";

    /// Build the client shared by the tests below.
    fn local_client() -> OllamaClient {
        OllamaClient::new(SERVER_URL, MODEL_NAME)
    }

    // Both tests need a running Ollama server, so they are ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_ollama_connection() {
        let available = local_client().is_available().await;
        assert!(available);
    }

    #[tokio::test]
    #[ignore]
    async fn test_ollama_generate() {
        let outcome = local_client().generate("Hello, how are you?").await;
        assert!(outcome.is_ok());
        let text = outcome.unwrap();
        println!("Response: {}", text);
    }
}
|