feat:(first commit)created repository and complete 0.1.0

This commit is contained in:
2026-01-30 14:18:32 +08:00
commit 5d4156e5e0
36 changed files with 8686 additions and 0 deletions

227
src/llm/anthropic.rs Normal file
View File

@@ -0,0 +1,227 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Anthropic Claude API client
pub struct AnthropicClient {
api_key: String,
model: String,
client: reqwest::Client,
}
#[derive(Debug, Serialize)]
struct MessagesRequest {
model: String,
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
messages: Vec<AnthropicMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct MessagesResponse {
content: Vec<ContentBlock>,
}
#[derive(Debug, Deserialize)]
struct ContentBlock {
#[serde(rename = "type")]
content_type: String,
text: String,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: AnthropicError,
}
#[derive(Debug, Deserialize)]
struct AnthropicError {
#[serde(rename = "type")]
error_type: String,
message: String,
}
impl AnthropicClient {
/// Create new Anthropic client
pub fn new(api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
let url = "https://api.anthropic.com/v1/messages";
let request = MessagesRequest {
model: self.model.clone(),
max_tokens: 5,
temperature: Some(0.0),
messages: vec![AnthropicMessage {
role: "user".to_string(),
content: "Hi".to_string(),
}],
system: None,
};
let response = self.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&request)
.send()
.await;
match response {
Ok(resp) => {
if resp.status().is_success() {
Ok(true)
} else {
let status = resp.status();
if status.as_u16() == 401 {
Ok(false)
} else {
let text = resp.text().await.unwrap_or_default();
bail!("Anthropic API error: {} - {}", status, text)
}
}
}
Err(e) => Err(e.into()),
}
}
}
#[async_trait]
impl LlmProvider for AnthropicClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![AnthropicMessage {
role: "user".to_string(),
content: prompt.to_string(),
}];
self.messages_request(messages, None).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let messages = vec![AnthropicMessage {
role: "user".to_string(),
content: user.to_string(),
}];
let system = if system.is_empty() {
None
} else {
Some(system.to_string())
};
self.messages_request(messages, system).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"anthropic"
}
}
impl AnthropicClient {
async fn messages_request(
&self,
messages: Vec<AnthropicMessage>,
system: Option<String>,
) -> Result<String> {
let url = "https://api.anthropic.com/v1/messages";
let request = MessagesRequest {
model: self.model.clone(),
max_tokens: 500,
temperature: Some(0.7),
messages,
system,
};
let response = self.client
.post(url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send request to Anthropic")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("Anthropic API error: {} ({})", error.error.message, error.error.error_type);
}
bail!("Anthropic API error: {} - {}", status, text);
}
let result: MessagesResponse = response
.json()
.await
.context("Failed to parse Anthropic response")?;
result.content
.into_iter()
.find(|c| c.content_type == "text")
.map(|c| c.text.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No text response from Anthropic"))
}
}
/// Available Anthropic models
pub const ANTHROPIC_MODELS: &[&str] = &[
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-2.0",
"claude-instant-1.2",
];
/// Check if a model name is valid
pub fn is_valid_model(model: &str) -> bool {
ANTHROPIC_MODELS.contains(&model)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_validation() {
assert!(is_valid_model("claude-3-sonnet-20240229"));
assert!(!is_valid_model("invalid-model"));
}
}

433
src/llm/mod.rs Normal file
View File

@@ -0,0 +1,433 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
pub mod ollama;
pub mod openai;
pub mod anthropic;
pub use ollama::OllamaClient;
pub use openai::OpenAiClient;
pub use anthropic::AnthropicClient;
/// LLM provider trait
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Generate text from prompt
async fn generate(&self, prompt: &str) -> Result<String>;
/// Generate with system prompt
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String>;
/// Check if provider is available
async fn is_available(&self) -> bool;
/// Get provider name
fn name(&self) -> &str;
}
/// LLM client that wraps different providers
pub struct LlmClient {
provider: Box<dyn LlmProvider>,
config: LlmClientConfig,
}
#[derive(Debug, Clone)]
pub struct LlmClientConfig {
pub max_tokens: u32,
pub temperature: f32,
pub timeout: Duration,
}
impl Default for LlmClientConfig {
fn default() -> Self {
Self {
max_tokens: 500,
temperature: 0.7,
timeout: Duration::from_secs(30),
}
}
}
impl LlmClient {
/// Create LLM client from configuration
pub async fn from_config(config: &crate::config::LlmConfig) -> Result<Self> {
let client_config = LlmClientConfig {
max_tokens: config.max_tokens,
temperature: config.temperature,
timeout: Duration::from_secs(config.timeout),
};
let provider: Box<dyn LlmProvider> = match config.provider.as_str() {
"ollama" => {
Box::new(OllamaClient::new(&config.ollama.url, &config.ollama.model))
}
"openai" => {
let api_key = config.openai.api_key.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenAI API key not configured"))?;
Box::new(OpenAiClient::new(
&config.openai.base_url,
api_key,
&config.openai.model,
)?)
}
"anthropic" => {
let api_key = config.anthropic.api_key.as_ref()
.ok_or_else(|| anyhow::anyhow!("Anthropic API key not configured"))?;
Box::new(AnthropicClient::new(api_key, &config.anthropic.model)?)
}
_ => bail!("Unknown LLM provider: {}", config.provider),
};
Ok(Self {
provider,
config: client_config,
})
}
/// Create with specific provider
pub fn with_provider(provider: Box<dyn LlmProvider>) -> Self {
Self {
provider,
config: LlmClientConfig::default(),
}
}
/// Generate commit message from git diff
pub async fn generate_commit_message(
&self,
diff: &str,
format: crate::config::CommitFormat,
) -> Result<GeneratedCommit> {
let system_prompt = match format {
crate::config::CommitFormat::Conventional => {
CONVENTIONAL_COMMIT_SYSTEM_PROMPT
}
crate::config::CommitFormat::Commitlint => {
COMMITLINT_SYSTEM_PROMPT
}
};
let prompt = format!("{}", diff);
let response = self.provider.generate_with_system(system_prompt, &prompt).await?;
self.parse_commit_response(&response, format)
}
/// Generate tag message from commits
pub async fn generate_tag_message(
&self,
version: &str,
commits: &[String],
) -> Result<String> {
let system_prompt = TAG_MESSAGE_SYSTEM_PROMPT;
let commits_text = commits.join("\n");
let prompt = format!("Version: {}\n\nCommits:\n{}", version, commits_text);
self.provider.generate_with_system(system_prompt, &prompt).await
}
/// Generate changelog entry
pub async fn generate_changelog_entry(
&self,
version: &str,
commits: &[(String, String)], // (type, message)
) -> Result<String> {
let system_prompt = CHANGELOG_SYSTEM_PROMPT;
let commits_text = commits
.iter()
.map(|(t, m)| format!("- [{}] {}", t, m))
.collect::<Vec<_>>()
.join("\n");
let prompt = format!("Version: {}\n\nCommits:\n{}", version, commits_text);
self.provider.generate_with_system(system_prompt, &prompt).await
}
/// Check if provider is available
pub async fn is_available(&self) -> bool {
self.provider.is_available().await
}
/// Parse commit response from LLM
fn parse_commit_response(&self, response: &str, format: crate::config::CommitFormat) -> Result<GeneratedCommit> {
let lines: Vec<&str> = response.lines().collect();
if lines.is_empty() {
bail!("Empty response from LLM");
}
let first_line = lines[0];
// Parse based on format
match format {
crate::config::CommitFormat::Conventional => {
self.parse_conventional_commit(first_line, lines)
}
crate::config::CommitFormat::Commitlint => {
self.parse_commitlint_commit(first_line, lines)
}
}
}
fn parse_conventional_commit(
&self,
first_line: &str,
lines: Vec<&str>,
) -> Result<GeneratedCommit> {
// Parse: type(scope)!: description
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 {
bail!("Invalid conventional commit format: missing colon");
}
let type_part = parts[0];
let description = parts[1].trim();
// Extract type, scope, and breaking indicator
let breaking = type_part.ends_with('!');
let type_part = type_part.trim_end_matches('!');
let (commit_type, scope) = if let Some(start) = type_part.find('(') {
if let Some(end) = type_part.find(')') {
let t = &type_part[..start];
let s = &type_part[start + 1..end];
(t.to_string(), Some(s.to_string()))
} else {
bail!("Invalid scope format: missing closing parenthesis");
}
} else {
(type_part.to_string(), None)
};
// Extract body and footer
let (body, footer) = self.extract_body_footer(&lines);
Ok(GeneratedCommit {
commit_type,
scope,
description: description.to_string(),
body,
footer,
breaking,
})
}
fn parse_commitlint_commit(
&self,
first_line: &str,
lines: Vec<&str>,
) -> Result<GeneratedCommit> {
// Similar parsing but with commitlint rules
let parts: Vec<&str> = first_line.splitn(2, ':').collect();
if parts.len() != 2 {
bail!("Invalid commit format: missing colon");
}
let type_part = parts[0];
let subject = parts[1].trim();
let (commit_type, scope) = if let Some(start) = type_part.find('(') {
if let Some(end) = type_part.find(')') {
let t = &type_part[..start];
let s = &type_part[start + 1..end];
(t.to_string(), Some(s.to_string()))
} else {
(type_part.to_string(), None)
}
} else {
(type_part.to_string(), None)
};
let (body, footer) = self.extract_body_footer(&lines);
Ok(GeneratedCommit {
commit_type,
scope,
description: subject.to_string(),
body,
footer,
breaking: false,
})
}
fn extract_body_footer(&self, lines: &[&str]) -> (Option<String>, Option<String>) {
if lines.len() <= 1 {
return (None, None);
}
let rest: Vec<&str> = lines[1..]
.iter()
.skip_while(|l| l.trim().is_empty())
.copied()
.collect();
if rest.is_empty() {
return (None, None);
}
// Look for footer markers
let footer_markers = ["BREAKING CHANGE:", "Closes", "Fixes", "Refs", "Co-authored-by:"];
let mut body_lines = vec![];
let mut footer_lines = vec![];
let mut in_footer = false;
for line in &rest {
if footer_markers.iter().any(|m| line.starts_with(m)) {
in_footer = true;
}
if in_footer {
footer_lines.push(*line);
} else {
body_lines.push(*line);
}
}
let body = if body_lines.is_empty() {
None
} else {
Some(body_lines.join("\n"))
};
let footer = if footer_lines.is_empty() {
None
} else {
Some(footer_lines.join("\n"))
};
(body, footer)
}
}
/// Generated commit structure
#[derive(Debug, Clone)]
pub struct GeneratedCommit {
pub commit_type: String,
pub scope: Option<String>,
pub description: String,
pub body: Option<String>,
pub footer: Option<String>,
pub breaking: bool,
}
impl GeneratedCommit {
/// Format as conventional commit
pub fn to_conventional(&self) -> String {
crate::utils::formatter::format_conventional_commit(
&self.commit_type,
self.scope.as_deref(),
&self.description,
self.body.as_deref(),
self.footer.as_deref(),
self.breaking,
)
}
/// Format as commitlint commit
pub fn to_commitlint(&self) -> String {
crate::utils::formatter::format_commitlint_commit(
&self.commit_type,
self.scope.as_deref(),
&self.description,
self.body.as_deref(),
self.footer.as_deref(),
None,
)
}
}
// System prompts for LLM
const CONVENTIONAL_COMMIT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates conventional commit messages.
Analyze the git diff provided and generate a commit message following the Conventional Commits specification.
Format: <type>[optional scope]: <description>
Types:
- feat: A new feature
- fix: A bug fix
- docs: Documentation only changes
- style: Changes that don't affect code meaning (formatting, semicolons, etc.)
- refactor: Code change that neither fixes a bug nor adds a feature
- perf: Code change that improves performance
- test: Adding or correcting tests
- build: Changes to build system or dependencies
- ci: Changes to CI configuration
- chore: Other changes that don't modify src or test files
- revert: Reverts a previous commit
Rules:
1. Use lowercase for type and scope
2. Keep description under 100 characters
3. Use imperative mood ("add" not "added")
4. Don't capitalize first letter
5. No period at the end
6. Include scope if the change is specific to a module/component
Output ONLY the commit message, nothing else.
"#;
const COMMITLINT_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates commit messages following @commitlint/config-conventional.
Analyze the git diff and generate a commit message.
Format: <type>[optional scope]: <subject>
Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
Rules:
1. Subject should not start with uppercase
2. Subject should not end with period
3. Subject should be 4-100 characters
4. Use imperative mood
5. Be concise but descriptive
Output ONLY the commit message, nothing else.
"#;
const TAG_MESSAGE_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates git tag annotation messages.
Given a version number and a list of commits, generate a concise but informative tag message.
The message should:
1. Start with a brief summary of the release
2. Group changes by type (features, fixes, etc.)
3. Be suitable for a git annotated tag
Format:
<version> Release
Summary of changes...
Changes:
- Feature: description
- Fix: description
...
"#;
const CHANGELOG_SYSTEM_PROMPT: &str = r#"You are a helpful assistant that generates changelog entries.
Given a version and a list of commits, generate a well-formatted changelog section.
Group commits by:
- Features (feat)
- Bug Fixes (fix)
- Documentation (docs)
- Other Changes
Format in markdown with proper headings and bullet points.
"#;
/// HTTP client helper
pub(crate) fn create_http_client(timeout: Duration) -> Result<reqwest::Client> {
reqwest::Client::builder()
.timeout(timeout)
.build()
.context("Failed to create HTTP client")
}

206
src/llm/ollama.rs Normal file
View File

@@ -0,0 +1,206 @@
use super::{create_http_client, LlmProvider};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Ollama API client
pub struct OllamaClient {
base_url: String,
model: String,
client: reqwest::Client,
}
#[derive(Debug, Serialize)]
struct GenerateRequest {
model: String,
prompt: String,
system: Option<String>,
stream: bool,
options: GenerationOptions,
}
#[derive(Debug, Serialize, Default)]
struct GenerationOptions {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct GenerateResponse {
response: String,
done: bool,
}
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
models: Vec<ModelInfo>,
}
#[derive(Debug, Deserialize)]
struct ModelInfo {
name: String,
}
impl OllamaClient {
/// Create new Ollama client
pub fn new(base_url: &str, model: &str) -> Self {
let client = create_http_client(Duration::from_secs(120))
.expect("Failed to create HTTP client");
Self {
base_url: base_url.trim_end_matches('/').to_string(),
model: model.to_string(),
client,
}
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.client = create_http_client(timeout)
.expect("Failed to create HTTP client");
self
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/api/tags", self.base_url);
let response = self.client
.get(&url)
.send()
.await
.context("Failed to list Ollama models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text);
}
let result: ListModelsResponse = response
.json()
.await
.context("Failed to parse Ollama response")?;
Ok(result.models.into_iter().map(|m| m.name).collect())
}
/// Pull a model
pub async fn pull_model(&self, model: &str) -> Result<()> {
let url = format!("{}/api/pull", self.base_url);
let request = serde_json::json!({
"name": model,
"stream": false,
});
let response = self.client
.post(&url)
.json(&request)
.send()
.await
.context("Failed to pull Ollama model")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama pull error: {} - {}", status, text);
}
Ok(())
}
/// Check if model exists
pub async fn model_exists(&self, model: &str) -> bool {
match self.list_models().await {
Ok(models) => models.contains(&model.to_string()),
Err(_) => false,
}
}
}
#[async_trait]
impl LlmProvider for OllamaClient {
async fn generate(&self, prompt: &str) -> Result<String> {
self.generate_with_system("", prompt).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let url = format!("{}/api/generate", self.base_url);
let system = if system.is_empty() {
None
} else {
Some(system.to_string())
};
let request = GenerateRequest {
model: self.model.clone(),
prompt: user.to_string(),
system,
stream: false,
options: GenerationOptions {
temperature: Some(0.7),
num_predict: Some(500),
},
};
let response = self.client
.post(&url)
.json(&request)
.send()
.await
.context("Failed to send request to Ollama")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error: {} - {}", status, text);
}
let result: GenerateResponse = response
.json()
.await
.context("Failed to parse Ollama response")?;
Ok(result.response.trim().to_string())
}
async fn is_available(&self) -> bool {
let url = format!("{}/api/tags", self.base_url);
match self.client.get(&url).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
fn name(&self) -> &str {
"ollama"
}
}
#[cfg(test)]
mod tests {
use super::*;
// These tests require a running Ollama server
#[tokio::test]
#[ignore]
async fn test_ollama_connection() {
let client = OllamaClient::new("http://localhost:11434", "llama2");
assert!(client.is_available().await);
}
#[tokio::test]
#[ignore]
async fn test_ollama_generate() {
let client = OllamaClient::new("http://localhost:11434", "llama2");
let response = client.generate("Hello, how are you?").await;
assert!(response.is_ok());
println!("Response: {}", response.unwrap());
}
}

345
src/llm/openai.rs Normal file
View File

@@ -0,0 +1,345 @@
use super::{create_http_client, LlmProvider};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// OpenAI API client
pub struct OpenAiClient {
base_url: String,
api_key: String,
model: String,
client: reqwest::Client,
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: Message,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ApiError,
}
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
#[serde(rename = "type")]
error_type: String,
}
impl OpenAiClient {
/// Create new OpenAI client
pub fn new(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
model: model.to_string(),
client,
})
}
/// Set timeout
pub fn with_timeout(mut self, timeout: Duration) -> Result<Self> {
self.client = create_http_client(timeout)?;
Ok(self)
}
/// List available models
pub async fn list_models(&self) -> Result<Vec<String>> {
let url = format!("{}/models", self.base_url);
let response = self.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await
.context("Failed to list OpenAI models")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("OpenAI API error: {} - {}", status, text);
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<Model>,
}
#[derive(Deserialize)]
struct Model {
id: String,
}
let result: ModelsResponse = response
.json()
.await
.context("Failed to parse OpenAI response")?;
Ok(result.data.into_iter().map(|m| m.id).collect())
}
/// Validate API key
pub async fn validate_key(&self) -> Result<bool> {
match self.list_models().await {
Ok(_) => Ok(true),
Err(e) => {
let err_str = e.to_string();
if err_str.contains("401") || err_str.contains("Unauthorized") {
Ok(false)
} else {
Err(e)
}
}
}
}
}
#[async_trait]
impl LlmProvider for OpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
self.validate_key().await.unwrap_or(false)
}
fn name(&self) -> &str {
"openai"
}
}
impl OpenAiClient {
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
stream: false,
};
let response = self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send request to OpenAI")?;
let status = response.status();
if !status.is_success() {
let text = response.text().await.unwrap_or_default();
// Try to parse error
if let Ok(error) = serde_json::from_str::<ErrorResponse>(&text) {
bail!("OpenAI API error: {} ({})", error.error.message, error.error.error_type);
}
bail!("OpenAI API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse OpenAI response")?;
result.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
}
}
/// Azure OpenAI client (extends OpenAI with Azure-specific config)
pub struct AzureOpenAiClient {
endpoint: String,
api_key: String,
deployment: String,
api_version: String,
client: reqwest::Client,
}
impl AzureOpenAiClient {
/// Create new Azure OpenAI client
pub fn new(
endpoint: &str,
api_key: &str,
deployment: &str,
api_version: &str,
) -> Result<Self> {
let client = create_http_client(Duration::from_secs(60))?;
Ok(Self {
endpoint: endpoint.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
deployment: deployment.to_string(),
api_version: api_version.to_string(),
client,
})
}
async fn chat_completion(&self, messages: Vec<Message>) -> Result<String> {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version
);
let request = ChatCompletionRequest {
model: self.deployment.clone(),
messages,
max_tokens: Some(500),
temperature: Some(0.7),
stream: false,
};
let response = self.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send request to Azure OpenAI")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
bail!("Azure OpenAI API error: {} - {}", status, text);
}
let result: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse Azure OpenAI response")?;
result.choices
.into_iter()
.next()
.map(|c| c.message.content.trim().to_string())
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
}
}
#[async_trait]
impl LlmProvider for AzureOpenAiClient {
async fn generate(&self, prompt: &str) -> Result<String> {
let messages = vec![
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
self.chat_completion(messages).await
}
async fn generate_with_system(&self, system: &str, user: &str) -> Result<String> {
let mut messages = vec![];
if !system.is_empty() {
messages.push(Message {
role: "system".to_string(),
content: system.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: user.to_string(),
});
self.chat_completion(messages).await
}
async fn is_available(&self) -> bool {
// Simple check - try to make a minimal request
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version
);
let request = ChatCompletionRequest {
model: self.deployment.clone(),
messages: vec![Message {
role: "user".to_string(),
content: "Hi".to_string(),
}],
max_tokens: Some(5),
temperature: Some(0.0),
stream: false,
};
match self.client
.post(&url)
.header("api-key", &self.api_key)
.json(&request)
.send()
.await
{
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
fn name(&self) -> &str {
"azure-openai"
}
}