feat:(first commit)created repository and complete 0.1.0
This commit is contained in:
227
src/llm/anthropic.rs
Normal file
227
src/llm/anthropic.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Anthropic Claude API client
|
||||
pub struct AnthropicClient {
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct MessagesRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct AnthropicMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessagesResponse {
|
||||
content: Vec<ContentBlock>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: AnthropicError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AnthropicError {
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Create new Anthropic client
|
||||
pub fn new(api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
let url = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
let request = MessagesRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: 5,
|
||||
temperature: Some(0.0),
|
||||
messages: vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hi".to_string(),
|
||||
}],
|
||||
system: None,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(url)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match response {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
Ok(true)
|
||||
} else {
|
||||
let status = resp.status();
|
||||
if status.as_u16() == 401 {
|
||||
Ok(false)
|
||||
} else {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
bail!("Anthropic API error: {} - {}", status, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AnthropicClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}];
|
||||
|
||||
self.messages_request(messages, None).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let messages = vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
}];
|
||||
|
||||
let system = if system.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system.to_string())
|
||||
};
|
||||
|
||||
self.messages_request(messages, system).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"anthropic"
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
async fn messages_request(
|
||||
&self,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
system: Option<String>,
|
||||
) -> Result<String> {
|
||||
let url = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
let request = MessagesRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: 500,
|
||||
temperature: Some(0.7),
|
||||
messages,
|
||||
system,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(url)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Anthropic")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
bail!("Anthropic API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: MessagesResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Anthropic response")?;
|
||||
|
||||
result.content
|
||||
.into_iter()
|
||||
.find(|c| c.content_type == "text")
|
||||
.map(|c| c.text.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Available Anthropic models
|
||||
pub const ANTHROPIC_MODELS: &[&str] = &[
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-2.1",
|
||||
"claude-2.0",
|
||||
"claude-instant-1.2",
|
||||
];
|
||||
|
||||
/// Check if a model name is valid
|
||||
pub fn is_valid_model(model: &str) -> bool {
|
||||
ANTHROPIC_MODELS.contains(&model)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_model_validation() {
|
||||
assert!(is_valid_model("claude-3-sonnet-20240229"));
|
||||
assert!(!is_valid_model("invalid-model"));
|
||||
}
|
||||
}
|
||||
433
src/llm/mod.rs
Normal file
433
src/llm/mod.rs
Normal file
@@ -0,0 +1,433 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod anthropic;
|
||||
|
||||
pub use ollama::OllamaClient;
|
||||
pub use openai::OpenAiClient;
|
||||
pub use anthropic::AnthropicClient;
|
||||
|
||||
/// LLM provider trait
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Generate text from prompt
|
||||
async fn generate(&self, prompt: &str) -> Result<String>;
|
||||
|
||||
/// Generate with system prompt
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>;
|
||||
|
||||
/// Check if provider is available
|
||||
async fn is_available(&self) -> bool;
|
||||
|
||||
/// Get provider name
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
|
||||
/// LLM client that wraps different providers
|
||||
pub struct LlmClient {
|
||||
provider: Box<dyn LlmProvider>,
|
||||
config: LlmClientConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlmClientConfig {
|
||||
pub max_tokens: u32,
|
||||
pub temperature: f32,
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for LlmClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_tokens: 500,
|
||||
temperature: 0.7,
|
||||
timeout: Duration::from_secs(30),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
/// Create LLM client from configuration
|
||||
pub async fn from_config(config: &crate::config::LlmConfig) -> Result<Self> {
|
||||
let client_config = LlmClientConfig {
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
timeout: Duration::from_secs(config.timeout),
|
||||
};
|
||||
|
||||
let provider: Box<dyn LlmProvider> = match config.provider.as_str() {
|
||||
"ollama" => {
|
||||
Box::new(OllamaClient::new(&config.ollama.url, &config.ollama.model))
|
||||
}
|
||||
"openai" => {
|
||||
let api_key = config.openai.api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
|
||||
Box::new(OpenAiClient::new(
|
||||
&config.openai.base_url,
|
||||
api_key,
|
||||
&config.openai.model,
|
||||
)?)
|
||||
}
|
||||
"anthropic" => {
|
||||
let api_key = config.anthropic.api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
|
||||
Box::new(AnthropicClient::new(api_key, &config.anthropic.model)?)
|
||||
}
|
||||
_ => bail!("Unknown LLM provider: {}", config.provider),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
config: client_config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create with specific provider
|
||||
pub fn with_provider(provider: Box<dyn LlmProvider>) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
config: LlmClientConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate commit message from git diff
|
||||
pub async fn generate_commit_message(
|
||||
&self,
|
||||
diff: &str,
|
||||
format: crate::config::CommitFormat,
|
||||
) -> Result<GeneratedCommit> {
|
||||
let system_prompt = match format {
|
||||
crate::config::CommitFormat::Conventional => {
|
||||
CONVENTIONAL_COMMIT_SYSTEM_PROMPT
|
||||
}
|
||||
crate::config::CommitFormat::Commitlint => {
|
||||
COMMITLINT_SYSTEM_PROMPT
|
||||
}
|
||||
};
|
||||
|
||||
let prompt = format!("{}", diff);
|
||||
let response = self.provider.generate_with_system(system_prompt, &prompt).await?;
|
||||
|
||||
self.parse_commit_response(&response, format)
|
||||
}
|
||||
|
||||
/// Generate tag message from commits
|
||||
pub async fn generate_tag_message(
|
||||
&self,
|
||||
version: &str,
|
||||
commits: &[String],
|
||||
) -> Result<String> {
|
||||
let system_prompt = TAG_MESSAGE_SYSTEM_PROMPT;
|
||||
let commits_text = commits.join("\n");
|
||||
let prompt = format!("Version: {}\n\nCommits:\n{}", version, commits_text);
|
||||
|
||||
self.provider.generate_with_system(system_prompt, &prompt).await
|
||||
}
|
||||
|
||||
/// Generate changelog entry
|
||||
pub async fn generate_changelog_entry(
|
||||
&self,
|
||||
version: &str,
|
||||
commits: &[(String, String)], // (type, message)
|
||||
) -> Result<String> {
|
||||
let system_prompt = CHANGELOG_SYSTEM_PROMPT;
|
||||
|
||||
let commits_text = commits
|
||||
.iter()
|
||||
.map(|(t, m)| format!("- [{}] {}", t, m))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let prompt = format!("Version: {}\n\nCommits:\n{}", version, commits_text);
|
||||
|
||||
self.provider.generate_with_system(system_prompt, &prompt).await
|
||||
}
|
||||
|
||||
/// Check if provider is available
|
||||
pub async fn is_available(&self) -> bool {
|
||||
self.provider.is_available().await
|
||||
}
|
||||
|
||||
/// Parse commit response from LLM
|
||||
fn parse_commit_response(&self, response: &str, format: crate::config::CommitFormat) -> Result<GeneratedCommit> {
|
||||
let lines: Vec<&str> = response.lines().collect();
|
||||
|
||||
if lines.is_empty() {
|
||||
bail!("Empty response from LLM");
|
||||
}
|
||||
|
||||
let first_line = lines[0];
|
||||
|
||||
// Parse based on format
|
||||
match format {
|
||||
crate::config::CommitFormat::Conventional => {
|
||||
self.parse_conventional_commit(first_line, lines)
|
||||
}
|
||||
crate::config::CommitFormat::Commitlint => {
|
||||
self.parse_commitlint_commit(first_line, lines)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_conventional_commit(
|
||||
&self,
|
||||
first_line: &str,
|
||||
lines: Vec<&str>,
|
||||
) -> Result<GeneratedCommit> {
|
||||
// Parse: type(scope)!: description
|
||||
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
bail!("Invalid conventional commit format: missing colon");
|
||||
}
|
||||
|
||||
let type_part = parts[0];
|
||||
let description = parts[1].trim();
|
||||
|
||||
// Extract type, scope, and breaking indicator
|
||||
let breaking = type_part.ends_with('!');
|
||||
let type_part = type_part.trim_end_matches('!');
|
||||
|
||||
let (commit_type, scope) = if let Some(start) = type_part.find('(') {
|
||||
if let Some(end) = type_part.find(')') {
|
||||
let t = &type_part[..start];
|
||||
let s = &type_part[start + 1..end];
|
||||
(t.to_string(), Some(s.to_string()))
|
||||
} else {
|
||||
bail!("Invalid scope format: missing closing parenthesis");
|
||||
}
|
||||
} else {
|
||||
(type_part.to_string(), None)
|
||||
};
|
||||
|
||||
// Extract body and footer
|
||||
let (body, footer) = self.extract_body_footer(&lines);
|
||||
|
||||
Ok(GeneratedCommit {
|
||||
commit_type,
|
||||
scope,
|
||||
description: description.to_string(),
|
||||
body,
|
||||
footer,
|
||||
breaking,
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_commitlint_commit(
|
||||
&self,
|
||||
first_line: &str,
|
||||
lines: Vec<&str>,
|
||||
) -> Result<GeneratedCommit> {
|
||||
// Similar parsing but with commitlint rules
|
||||
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
bail!("Invalid commit format: missing colon");
|
||||
}
|
||||
|
||||
let type_part = parts[0];
|
||||
let subject = parts[1].trim();
|
||||
|
||||
let (commit_type, scope) = if let Some(start) = type_part.find('(') {
|
||||
if let Some(end) = type_part.find(')') {
|
||||
let t = &type_part[..start];
|
||||
let s = &type_part[start + 1..end];
|
||||
(t.to_string(), Some(s.to_string()))
|
||||
} else {
|
||||
(type_part.to_string(), None)
|
||||
}
|
||||
} else {
|
||||
(type_part.to_string(), None)
|
||||
};
|
||||
|
||||
let (body, footer) = self.extract_body_footer(&lines);
|
||||
|
||||
Ok(GeneratedCommit {
|
||||
commit_type,
|
||||
scope,
|
||||
description: subject.to_string(),
|
||||
body,
|
||||
footer,
|
||||
breaking: false,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_body_footer(&self, lines: &[&str]) -> (Option<String>, Option<String>) {
|
||||
if lines.len() <= 1 {
|
||||
return (None, None);
|
||||
}
|
||||
|
||||
let rest: Vec<&str> = lines[1..]
|
||||
.iter()
|
||||
.skip_while(|l| l.trim().is_empty())
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
if rest.is_empty() {
|
||||
return (None, None);
|
||||
}
|
||||
|
||||
// Look for footer markers
|
||||
let footer_markers = ["BREAKING CHANGE:", "Closes", "Fixes", "Refs", "Co-authored-by:"];
|
||||
|
||||
let mut body_lines = vec![];
|
||||
let mut footer_lines = vec![];
|
||||
let mut in_footer = false;
|
||||
|
||||
for line in &rest {
|
||||
if footer_markers.iter().any(|m| line.starts_with(m)) {
|
||||
in_footer = true;
|
||||
}
|
||||
|
||||
if in_footer {
|
||||
footer_lines.push(*line);
|
||||
} else {
|
||||
body_lines.push(*line);
|
||||
}
|
||||
}
|
||||
|
||||
let body = if body_lines.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(body_lines.join("\n"))
|
||||
};
|
||||
|
||||
let footer = if footer_lines.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(footer_lines.join("\n"))
|
||||
};
|
||||
|
||||
(body, footer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generated commit structure
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GeneratedCommit {
|
||||
pub commit_type: String,
|
||||
pub scope: Option<String>,
|
||||
pub description: String,
|
||||
pub body: Option<String>,
|
||||
pub footer: Option<String>,
|
||||
pub breaking: bool,
|
||||
}
|
||||
|
||||
impl GeneratedCommit {
|
||||
/// Format as conventional commit
|
||||
pub fn to_conventional(&self) -> String {
|
||||
crate::utils::formatter::format_conventional_commit(
|
||||
&self.commit_type,
|
||||
self.scope.as_deref(),
|
||||
&self.description,
|
||||
self.body.as_deref(),
|
||||
self.footer.as_deref(),
|
||||
self.breaking,
|
||||
)
|
||||
}
|
||||
|
||||
/// Format as commitlint commit
|
||||
pub fn to_commitlint(&self) -> String {
|
||||
crate::utils::formatter::format_commitlint_commit(
|
||||
&self.commit_type,
|
||||
self.scope.as_deref(),
|
||||
&self.description,
|
||||
self.body.as_deref(),
|
||||
self.footer.as_deref(),
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// System prompts for LLM
|
||||
|
||||
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates conventional commit messages.
|
||||
|
||||
Analyze the git diff provided and generate a commit message following the Conventional Commits specification.
|
||||
|
||||
Format: <type>[optional scope]: <description>
|
||||
|
||||
Types:
|
||||
- feat: A new feature
|
||||
- fix: A bug fix
|
||||
- docs: Documentation only changes
|
||||
- style: Changes that don't affect code meaning (formatting, semicolons, etc.)
|
||||
- refactor: Code change that neither fixes a bug nor adds a feature
|
||||
- perf: Code change that improves performance
|
||||
- test: Adding or correcting tests
|
||||
- build: Changes to build system or dependencies
|
||||
- ci: Changes to CI configuration
|
||||
- chore: Other changes that don't modify src or test files
|
||||
- revert: Reverts a previous commit
|
||||
|
||||
Rules:
|
||||
1. Use lowercase for type and scope
|
||||
2. Keep description under 100 characters
|
||||
3. Use imperative mood ("add" not "added")
|
||||
4. Don't capitalize first letter
|
||||
5. No period at the end
|
||||
6. Include scope if the change is specific to a module/component
|
||||
|
||||
Output ONLY the commit message, nothing else.
|
||||
"#;
|
||||
|
||||
const COMMITLINT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates commit messages following @commitlint/config-conventional.
|
||||
|
||||
Analyze the git diff and generate a commit message.
|
||||
|
||||
Format: <type>[optional scope]: <subject>
|
||||
|
||||
Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
|
||||
|
||||
Rules:
|
||||
1. Subject should not start with uppercase
|
||||
2. Subject should not end with period
|
||||
3. Subject should be 4-100 characters
|
||||
4. Use imperative mood
|
||||
5. Be concise but descriptive
|
||||
|
||||
Output ONLY the commit message, nothing else.
|
||||
"#;
|
||||
|
||||
const TAG_MESSAGE_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates git tag annotation messages.
|
||||
|
||||
Given a version number and a list of commits, generate a concise but informative tag message.
|
||||
|
||||
The message should:
|
||||
1. Start with a brief summary of the release
|
||||
2. Group changes by type (features, fixes, etc.)
|
||||
3. Be suitable for a git annotated tag
|
||||
|
||||
Format:
|
||||
<version> Release
|
||||
|
||||
Summary of changes...
|
||||
|
||||
Changes:
|
||||
- Feature: description
|
||||
- Fix: description
|
||||
...
|
||||
"#;
|
||||
|
||||
const CHANGELOG_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates changelog entries.
|
||||
|
||||
Given a version and a list of commits, generate a well-formatted changelog section.
|
||||
|
||||
Group commits by:
|
||||
- Features (feat)
|
||||
- Bug Fixes (fix)
|
||||
- Documentation (docs)
|
||||
- Other Changes
|
||||
|
||||
Format in markdown with proper headings and bullet points.
|
||||
"#;
|
||||
|
||||
/// HTTP client helper
|
||||
pub(crate) fn create_http_client(timeout: Duration) -> Result<reqwest::Client> {
|
||||
reqwest::Client::builder()
|
||||
.timeout(timeout)
|
||||
.build()
|
||||
.context("Failed to create HTTP client")
|
||||
}
|
||||
206
src/llm/ollama.rs
Normal file
206
src/llm/ollama.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Ollama API client
|
||||
pub struct OllamaClient {
|
||||
base_url: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GenerateRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
system: Option<String>,
|
||||
stream: bool,
|
||||
options: GenerationOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Default)]
|
||||
struct GenerationOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
num_predict: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GenerateResponse {
|
||||
response: String,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListModelsResponse {
|
||||
models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ModelInfo {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Create new Ollama client
|
||||
pub fn new(base_url: &str, model: &str) -> Self {
|
||||
let client = create_http_client(Duration::from_secs(120))
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.client = create_http_client(timeout)
|
||||
.expect("Failed to create HTTP client");
|
||||
self
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list Ollama models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ListModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
Ok(result.models.into_iter().map(|m| m.name).collect())
|
||||
}
|
||||
|
||||
/// Pull a model
|
||||
pub async fn pull_model(&self, model: &str) -> Result<()> {
|
||||
let url = format!("{}/api/pull", self.base_url);
|
||||
|
||||
let request = serde_json::json!({
|
||||
"name": model,
|
||||
"stream": false,
|
||||
});
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to pull Ollama model")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama pull error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if model exists
|
||||
pub async fn model_exists(&self, model: &str) -> bool {
|
||||
match self.list_models().await {
|
||||
Ok(models) => models.contains(&model.to_string()),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OllamaClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
self.generate_with_system("", prompt).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let url = format!("{}/api/generate", self.base_url);
|
||||
|
||||
let system = if system.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system.to_string())
|
||||
};
|
||||
|
||||
let request = GenerateRequest {
|
||||
model: self.model.clone(),
|
||||
prompt: user.to_string(),
|
||||
system,
|
||||
stream: false,
|
||||
options: GenerationOptions {
|
||||
temperature: Some(0.7),
|
||||
num_predict: Some(500),
|
||||
},
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Ollama")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: GenerateResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
Ok(result.response.trim().to_string())
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
|
||||
match self.client.get(&url).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ollama"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// These tests require a running Ollama server
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_ollama_connection() {
|
||||
let client = OllamaClient::new("http://localhost:11434", "llama2");
|
||||
assert!(client.is_available().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_ollama_generate() {
|
||||
let client = OllamaClient::new("http://localhost:11434", "llama2");
|
||||
let response = client.generate("Hello, how are you?").await;
|
||||
assert!(response.is_ok());
|
||||
println!("Response: {}", response.unwrap());
|
||||
}
|
||||
}
|
||||
345
src/llm/openai.rs
Normal file
345
src/llm/openai.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
use super::{create_http_client, LlmProvider};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
/// OpenAI API client
|
||||
pub struct OpenAiClient {
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f32>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiError {
|
||||
message: String,
|
||||
#[serde(rename = "type")]
|
||||
error_type: String,
|
||||
}
|
||||
|
||||
impl OpenAiClient {
|
||||
/// Create new OpenAI client
|
||||
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set timeout
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
|
||||
self.client = create_http_client(timeout)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// List available models
|
||||
pub async fn list_models(&self) -> Result<Vec<String>> {
|
||||
let url = format!("{}/models", self.base_url);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list OpenAI models")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelsResponse {
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Model {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let result: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenAI response")?;
|
||||
|
||||
Ok(result.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Validate API key
|
||||
pub async fn validate_key(&self) -> Result<bool> {
|
||||
match self.list_models().await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("401") || err_str.contains("Unauthorized") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenAiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.validate_key().await.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"openai"
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiClient {
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to OpenAI")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Try to parse error
|
||||
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type);
|
||||
}
|
||||
|
||||
bail!("OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse OpenAI response")?;
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
|
||||
pub struct AzureOpenAiClient {
|
||||
endpoint: String,
|
||||
api_key: String,
|
||||
deployment: String,
|
||||
api_version: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl AzureOpenAiClient {
|
||||
/// Create new Azure OpenAI client
|
||||
pub fn new(
|
||||
endpoint: &str,
|
||||
api_key: &str,
|
||||
deployment: &str,
|
||||
api_version: &str,
|
||||
) -> Result<Self> {
|
||||
let client = create_http_client(Duration::from_secs(60))?;
|
||||
|
||||
Ok(Self {
|
||||
endpoint: endpoint.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.to_string(),
|
||||
deployment: deployment.to_string(),
|
||||
api_version: api_version.to_string(),
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint, self.deployment, self.api_version
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.deployment.clone(),
|
||||
messages,
|
||||
max_tokens: Some(500),
|
||||
temperature: Some(0.7),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send request to Azure OpenAI")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
bail!("Azure OpenAI API error: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let result: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Azure OpenAI response")?;
|
||||
|
||||
result.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content.trim().to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AzureOpenAiClient {
|
||||
async fn generate(&self, prompt: &str) -> Result<String> {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
|
||||
let mut messages = vec![];
|
||||
|
||||
if !system.is_empty() {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: system.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: user.to_string(),
|
||||
});
|
||||
|
||||
self.chat_completion(messages).await
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
// Simple check - try to make a minimal request
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint, self.deployment, self.api_version
|
||||
);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.deployment.clone(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi".to_string(),
|
||||
}],
|
||||
max_tokens: Some(5),
|
||||
temperature: Some(0.0),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
match self.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"azure-openai"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user