llm-worker-rs/worker/src/llm/ollama.rs

use crate::{
LlmClientTrait, WorkerError,
types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
// --- Request & Response Structures ---
#[derive(Debug, Serialize, Clone)]
pub struct OllamaRequest {
pub model: String,
pub messages: Vec<OllamaMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<OllamaTool>>,
}
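// Illustrative serialized OllamaRequest (field values are hypothetical, not taken
// from this codebase); `tools` is omitted entirely when it is `None`:
// {
//   "model": "llama3",
//   "messages": [{"role": "user", "content": "Hello"}],
//   "stream": true,
//   "tools": [{"type": "function",
//              "function": {"name": "List", "description": "...", "parameters": {...}}}]
// }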
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaMessage {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<OllamaToolCall>>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCall {
pub function: OllamaToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCallFunction {
pub name: String,
#[serde(
serialize_with = "serialize_arguments",
deserialize_with = "deserialize_arguments"
)]
pub arguments: String,
}
/// Custom serializer for the arguments field: JSON-parseable strings are emitted as raw JSON values, anything else as a plain string
fn serialize_arguments<S>(arguments: &str, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// Try to parse as JSON first; if successful, serialize as a raw JSON value.
// If not valid JSON, serialize as a plain string.
if let Ok(value) = serde_json::from_str::<serde_json::Value>(arguments) {
value.serialize(serializer)
} else {
arguments.serialize(serializer)
}
}
/// Custom deserializer for the arguments field that handles string, object, and array formats
fn deserialize_arguments<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let value: serde_json::Value = serde::Deserialize::deserialize(deserializer)?;
match value {
serde_json::Value::String(s) => Ok(s),
serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
// If it's an object or array, serialize it back to a JSON string
serde_json::to_string(&value).map_err(D::Error::custom)
}
_ => Err(D::Error::custom("arguments must be a string or object")),
}
}
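// Example: an incoming `"arguments": {"path": "./"}` (a JSON object) deserializes into
// the JSON string `{"path":"./"}`, while an incoming string form is kept as-is; on
// serialization, any JSON-parseable string is emitted back out as a raw JSON value.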
#[derive(Debug, Serialize, Clone)]
pub struct OllamaTool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: OllamaFunction,
}
#[derive(Debug, Serialize, Clone)]
pub struct OllamaFunction {
pub name: String,
pub description: String,
pub parameters: Value,
}
#[derive(Debug, Deserialize)]
pub struct OllamaResponse {
pub message: OllamaMessage,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct OllamaStreamResponse {
pub message: OllamaMessage,
pub done: bool,
}
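// The /api/chat streaming endpoint returns newline-delimited JSON; each line parses
// into one OllamaStreamResponse, e.g. (illustrative values):
// {"message": {"role": "assistant", "content": "Hel"}, "done": false}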
#[derive(Debug, Deserialize)]
pub struct OllamaModelShowResponse {
pub details: Option<OllamaModelDetails>,
pub model_info: Option<serde_json::Value>,
pub template: Option<String>,
pub system: Option<String>,
pub parameters: Option<serde_json::Value>,
}
#[derive(Debug, Deserialize)]
pub struct OllamaModelDetails {
pub format: Option<String>,
pub family: Option<String>,
pub families: Option<Vec<String>>,
pub parameter_size: Option<String>,
pub quantization_level: Option<String>,
}
// --- Client ---
pub struct OllamaClient {
model: String,
base_url: String,
api_key: Option<String>,
}
impl OllamaClient {
pub fn new(model: &str) -> Self {
Self {
model: model.to_string(),
base_url: UrlConfig::get_base_url("ollama"),
api_key: None,
}
}
pub fn new_with_key(api_key: &str, model: &str) -> Self {
tracing::debug!(
"Ollama: Creating client with API key (length: {}), model: {}",
api_key.len(),
model
);
Self {
model: model.to_string(),
base_url: UrlConfig::get_base_url("ollama"),
api_key: Some(api_key.to_string()),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
fn add_auth_header(&self, request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
tracing::debug!(
"Ollama: add_auth_header called, api_key present: {}",
self.api_key.is_some()
);
if let Some(ref api_key) = self.api_key {
// Detailed API key logging removed (for security and readability)
// Add the Authorization header only when the API key is non-empty
if !api_key.trim().is_empty() {
// Check whether the API key is already a formatted header value
if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
// Already formatted (e.g. "Basic base64string" or "Bearer token")
// Detailed auth header logging removed (for security and readability)
request_builder.header("Authorization", api_key)
} else {
// Choose the auth scheme based on the base URL
let auth_header = if self.base_url.contains("ollama.com") {
// Use a Bearer token for ollama.com
format!("Bearer {}", api_key)
} else {
// Otherwise use Basic auth (for local / proxy setups)
format!("Basic {}", api_key)
};
// Detailed auth header logging removed (for security and readability)
request_builder.header("Authorization", auth_header)
}
} else {
tracing::debug!("Ollama: Empty API key, skipping auth header");
request_builder
}
} else {
tracing::debug!(
"Ollama: No API key provided, using unauthenticated request (typical for local Ollama)"
);
request_builder
}
}
/// Static method: fetch the model list from the Ollama server (uses the default URL)
pub async fn list_models_static(
api_key: &str,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_models_url("ollama");
tracing::debug!("Ollama list_models_static requesting: {}", url);
let mut request_builder = client.get(&url);
if !api_key.trim().is_empty() {
// Check whether the API key is already a formatted header value
if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
// Already formatted
request_builder = request_builder.header("Authorization", api_key);
} else {
// Choose the auth scheme based on the URL
let auth_header = if url.contains("ollama.com") {
// Use a Bearer token for ollama.com
format!("Bearer {}", api_key)
} else {
// Otherwise use Basic auth (for local / proxy setups)
format!("Basic {}", api_key)
};
// Detailed auth header logging removed (for security and readability)
request_builder = request_builder.header("Authorization", auth_header);
}
}
let response = request_builder.send().await.map_err(|e| {
tracing::error!("Ollama API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let status = response.status();
tracing::info!("Ollama list_models_static response status: {}", status);
if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Ollama list_models_static failed - Status: {}, Body: {}",
status,
error_body
);
let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
return Err(WorkerError::from_api_error(
error_msg,
&crate::types::LlmProvider::Ollama,
));
}
let response_text = response.text().await.map_err(|e| {
tracing::error!("Failed to read Ollama response text: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
// Detailed raw response logging removed (for readability)
let models_response: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| {
tracing::error!(
"Failed to parse Ollama JSON response: {} - Response: {}",
e,
response_text
);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
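// The /api/tags response is expected to look like {"models": [{"name": "..."}, ...]};
// only each entry's "name" field is used below.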
let mut models = Vec::new();
if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
for model in models_array {
if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
models.push(crate::types::ModelInfo {
id: name.to_string(),
name: name.to_string(),
provider: crate::types::LlmProvider::Ollama,
supports_tools: true, // Will be determined by config
supports_function_calling: true,
supports_vision: false, // Will be determined by config
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: Some(format!("Ollama model: {}", name)),
});
}
}
}
tracing::info!(
"Ollama list_models_static found {} models with metadata",
models.len()
);
Ok(models)
}
// list_models_with_info was removed - models should be configured in models.yaml
// This private method is kept for future reference if needed
#[allow(dead_code)]
async fn list_models_with_info_internal(
&self,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = format!("{}/api/tags", self.base_url);
tracing::debug!("Ollama list_models requesting: {}", url);
let request = self.add_auth_header(client.get(&url));
tracing::debug!("Ollama list_models_with_info sending request to: {}", &url);
let response = request.send().await.map_err(|e| {
tracing::error!("Ollama API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let status = response.status();
tracing::info!("Ollama list_models response status: {}", status);
if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Ollama list_models failed - Status: {}, Body: {}, URL: {}",
status,
error_body,
&url
);
let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
return Err(WorkerError::from_api_error(
error_msg,
&crate::types::LlmProvider::Ollama,
));
}
let response_text = response.text().await.map_err(|e| {
tracing::error!("Failed to read Ollama response text: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
// Detailed raw response logging removed (for readability)
let models_response: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| {
tracing::error!(
"Failed to parse Ollama JSON response: {} - Response: {}",
e,
response_text
);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let model_names: Vec<String> = models_response
.get("models")
.and_then(|models| models.as_array())
.ok_or_else(|| {
tracing::error!("Invalid Ollama models response format - missing 'models' array");
WorkerError::LlmApiError("Invalid models response format".to_string())
})?
.iter()
.filter_map(|model| {
model
.get("name")
.and_then(|name| name.as_str())
.map(|s| s.to_string())
})
.collect();
// Build a basic ModelInfo entry for each model name (detailed capabilities come from config)
let mut models = Vec::new();
for name in model_names {
models.push(crate::types::ModelInfo {
id: name.clone(),
name: name.clone(),
provider: crate::types::LlmProvider::Ollama,
supports_tools: true, // Will be determined by config
supports_function_calling: true,
supports_vision: false, // Will be determined by config
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: Some(format!("Ollama model: {}", name)),
});
}
tracing::info!(
"Ollama list_models found {} models with dynamic capability detection",
models.len()
);
Ok(models)
}
}
use async_stream::stream;
impl OllamaClient {
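/// Streams a chat response from the Ollama `/api/chat` endpoint, yielding
/// `StreamEvent::Chunk` for content deltas, `StreamEvent::ToolCall` for tool calls,
/// and a final `StreamEvent::Completion` once the server reports `done: true`.
///
/// Minimal usage sketch (illustrative only; assumes an async runtime, a reachable
/// Ollama server, and the hypothetical model name "llama3"; not compiled as a doctest):
///
/// ```ignore
/// use futures_util::StreamExt;
/// let client = OllamaClient::new("llama3");
/// let mut stream = client
///     .chat_stream(vec![Message::new(Role::User, "Hello".to_string())], None, None)
///     .await?;
/// while let Some(event) = stream.next().await {
///     if let Ok(StreamEvent::Chunk(text)) = event {
///         print!("{text}");
///     }
/// }
/// ```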
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
let client = Client::new();
let url = format!("{}/api/chat", self.base_url);
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
// Convert tool calls if present
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: match msg.role {
Role::User => "user".to_string(),
Role::Model => "assistant".to_string(),
Role::System => "system".to_string(),
Role::Tool => "tool".to_string(),
},
content: msg.content,
tool_calls,
}
})
.collect();
// Convert tools to Ollama format (similar to OpenAI)
let ollama_tools = tools.map(|tools| {
tools
.iter()
.map(|tool| OllamaTool {
tool_type: "function".to_string(),
function: OllamaFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters_schema.clone(),
},
})
.collect()
});
let request = OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
stream: true,
tools: ollama_tools,
};
let stream = stream! {
// Emit debug information for the request
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_request(&self.model, "Ollama", &serde_json::to_value(&request).unwrap_or_default()) {
yield Ok(debug_event);
}
}
// Log the outgoing request details
tracing::info!("Ollama chat_stream: Sending request to {}", &url);
tracing::debug!("Ollama request model: {}", &request.model);
tracing::debug!("Ollama request messages count: {}", request.messages.len());
if let Some(ref tools) = request.tools {
tracing::debug!("Ollama request tools count: {}", tools.len());
}
// Detailed request body logging removed (for readability)
let request_builder = self.add_auth_header(client.post(&url));
let response = request_builder
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| {
tracing::error!("Ollama chat_stream request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
});
let response = match response {
Ok(resp) => resp,
Err(e) => {
yield Err(e);
return;
}
};
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!("Ollama chat_stream failed - Status: {}, Body: {}, URL: {}", status, error_body, &url);
yield Err(WorkerError::from_api_error(
format!("Ollama API error: {} - {}", status, error_body),
&crate::types::LlmProvider::Ollama,
));
return;
} else {
tracing::info!("Ollama chat_stream response status: {}", response.status());
}
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
let mut full_content = String::new();
let mut chunk_count = 0;
tracing::debug!("Ollama chat_stream: Starting to process response stream");
while let Some(chunk) = byte_stream.next().await {
match chunk {
Ok(bytes) => {
chunk_count += 1;
let chunk_str = String::from_utf8_lossy(&bytes);
// Detailed chunk logging removed (for readability)
buffer.push_str(&chunk_str);
// Process line by line
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].to_string();
buffer = buffer[line_end + 1..].to_string();
if line.trim().is_empty() {
continue;
}
// Detailed per-line stream logging removed (for readability)
match serde_json::from_str::<OllamaStreamResponse>(&line) {
Ok(stream_response) => {
// Emit debug information for the response
if let Some(ref debug) = llm_debug {
if let Some(debug_event) = debug.debug_response(&self.model, "Ollama", &serde_json::to_value(&stream_response).unwrap_or_default()) {
yield Ok(debug_event);
}
}
// Handle tool calls
if let Some(tool_calls) = &stream_response.message.tool_calls {
tracing::info!("Ollama stream response contains {} tool calls", tool_calls.len());
for (i, tool_call) in tool_calls.iter().enumerate() {
tracing::debug!("Tool call #{}: name={}, arguments={}",
i + 1, tool_call.function.name, tool_call.function.arguments);
let parsed_tool_call = ToolCall {
name: tool_call.function.name.clone(),
arguments: tool_call.function.arguments.clone(),
};
yield Ok(StreamEvent::ToolCall(parsed_tool_call));
}
}
// Handle regular content
if !stream_response.message.content.is_empty() {
full_content.push_str(&stream_response.message.content);
yield Ok(StreamEvent::Chunk(stream_response.message.content));
}
if stream_response.done {
tracing::info!("Ollama stream completed, total content: {} chars", full_content.len());
tracing::debug!("Ollama complete response content: {}", full_content);
yield Ok(StreamEvent::Completion(Message::new(
Role::Model,
full_content.clone(),
)));
break;
}
}
Err(e) => {
tracing::warn!("Failed to parse Ollama stream response: {} - Line: {}", e, line);
tracing::debug!("Parse error details: line_length={}, error={}", line.len(), e);
}
}
}
}
Err(e) => {
tracing::error!("Ollama stream error after {} chunks: {}", chunk_count, e);
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama));
break;
}
}
}
tracing::debug!("Ollama chat_stream: Stream ended, processed {} chunks", chunk_count);
};
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_name: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = format!("{}/api/show", self.base_url);
let request = serde_json::json!({
"name": model_name
});
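// /api/show is expected to return a JSON object whose "details" (family,
// parameter_size, quantization_level), "size", and "modified_at" fields are
// read below; anything missing falls back to defaults.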
let response = self
.add_auth_header(client.post(&url))
.json(&request)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"Ollama model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::Ollama,
));
}
let model_data: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
})?;
let details = model_data
.get("details")
.unwrap_or(&serde_json::Value::Null);
let family = details
.get("family")
.and_then(|f| f.as_str())
.unwrap_or("unknown");
let parameter_size = details
.get("parameter_size")
.and_then(|p| p.as_str())
.unwrap_or("unknown");
let quantization = details
.get("quantization_level")
.and_then(|q| q.as_str())
.unwrap_or("unknown");
let size = model_data.get("size").and_then(|s| s.as_u64()).unwrap_or(0);
let modified_at = model_data
.get("modified_at")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let supports_tools = true; // Will be determined by config
let context_length = None; // Will be determined by config
let capabilities = vec!["text_generation".to_string()]; // Basic default
let description = format!("Ollama model: {}", model_name);
Ok(crate::types::ModelInfo {
id: model_name.to_string(),
name: format!("{} ({}, {})", model_name, family, parameter_size),
provider: crate::types::LlmProvider::Ollama,
supports_tools,
supports_function_calling: supports_tools,
supports_vision: false, // Will be determined dynamically
supports_multimodal: false,
context_length,
training_cutoff: modified_at,
capabilities,
description: Some(format!(
"{} (Size: {} bytes, Quantization: {})",
description, size, quantization
)),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
let client = Client::new();
let url = format!("{}/api/tags", self.base_url);
self.add_auth_header(client.get(&url))
.send()
.await
.map_err(|e| WorkerError::LlmApiError(format!("Failed to connect to Ollama: {}", e)))?;
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for OllamaClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Ollama
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Message, Role, ToolCall};
#[test]
fn test_message_conversion_with_tool_calls() {
let tool_call = ToolCall {
name: "List".to_string(),
arguments: r#"{"path": "./"}"#.to_string(),
};
let message = Message::with_tool_calls(
Role::Model,
"".to_string(), // Empty content, only tool calls
vec![tool_call.clone()],
);
let messages = vec![message];
// Simulate the conversion that happens in chat_stream
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
// Convert tool calls if present
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: "assistant".to_string(),
content: msg.content,
tool_calls,
}
})
.collect();
// Verify the conversion preserved tool calls
assert_eq!(ollama_messages.len(), 1);
let converted_msg = &ollama_messages[0];
assert_eq!(converted_msg.role, "assistant");
assert_eq!(converted_msg.content, "");
assert!(converted_msg.tool_calls.is_some());
let converted_tool_calls = converted_msg.tool_calls.as_ref().unwrap();
assert_eq!(converted_tool_calls.len(), 1);
assert_eq!(converted_tool_calls[0].function.name, "List");
assert_eq!(
converted_tool_calls[0].function.arguments,
r#"{"path": "./"}"#
);
}
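// Sketch of an additional test (not in the original file) exercising the custom
// `arguments` (de)serializers through OllamaToolCallFunction round-tripping.
#[test]
fn test_tool_call_arguments_roundtrip() {
// Object-form arguments should deserialize into their JSON string representation
let json = r#"{"name": "List", "arguments": {"path": "./"}}"#;
let func: OllamaToolCallFunction = serde_json::from_str(json).unwrap();
assert_eq!(func.name, "List");
let parsed: serde_json::Value = serde_json::from_str(&func.arguments).unwrap();
assert_eq!(parsed["path"], "./");
// Serializing back should emit the arguments as a JSON object, not a quoted string
let serialized = serde_json::to_value(&func).unwrap();
assert!(serialized["arguments"].is_object());
}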
#[test]
fn test_message_conversion_without_tool_calls() {
let message = Message::new(Role::User, "Hello".to_string());
let messages = vec![message];
let ollama_messages: Vec<OllamaMessage> = messages
.into_iter()
.map(|msg| {
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|call| OllamaToolCall {
function: OllamaToolCallFunction {
name: call.name,
arguments: call.arguments,
},
})
.collect()
});
OllamaMessage {
role: "user".to_string(),
content: msg.content,
tool_calls,
}
})
.collect();
assert_eq!(ollama_messages.len(), 1);
let converted_msg = &ollama_messages[0];
assert_eq!(converted_msg.role, "user");
assert_eq!(converted_msg.content, "Hello");
assert!(converted_msg.tool_calls.is_none());
}
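// Sketch of an additional test (not in the original file) covering add_auth_header's
// scheme selection: Bearer for ollama.com, Basic otherwise. The clients are constructed
// directly with hypothetical base URLs so no network access is needed.
#[test]
fn test_add_auth_header_scheme_selection() {
let client = Client::new();
let hosted = OllamaClient {
model: "test-model".to_string(),
base_url: "https://ollama.com".to_string(),
api_key: Some("secret".to_string()),
};
let req = hosted
.add_auth_header(client.get("https://ollama.com/api/chat"))
.build()
.unwrap();
assert_eq!(
req.headers().get("Authorization").unwrap().to_str().unwrap(),
"Bearer secret"
);
let local = OllamaClient {
model: "test-model".to_string(),
base_url: "http://localhost:11434".to_string(),
api_key: Some("secret".to_string()),
};
let req = local
.add_auth_header(client.get("http://localhost:11434/api/chat"))
.build()
.unwrap();
assert_eq!(
req.headers().get("Authorization").unwrap().to_str().unwrap(),
"Basic secret"
);
}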
}