use crate::{
    LlmClientTrait, WorkerError,
    types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall},
    url_config::UrlConfig,
};
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;

// --- Request & Response Structures ---
#[derive(Debug, Serialize, Clone)]
pub struct OllamaRequest {
    pub model: String,
    pub messages: Vec<OllamaMessage>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OllamaTool>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaMessage {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OllamaToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCall {
    pub function: OllamaToolCallFunction,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OllamaToolCallFunction {
    pub name: String,
    #[serde(
        serialize_with = "serialize_arguments",
        deserialize_with = "deserialize_arguments"
    )]
    pub arguments: String,
}
/// Custom serializer for the arguments field: strings containing valid JSON
/// are re-emitted as raw JSON values; anything else is emitted as a plain string.
fn serialize_arguments<S>(arguments: &str, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // Try to parse as JSON first; if successful, serialize as raw JSON.
    // If not valid JSON, serialize as a string.
    if let Ok(value) = serde_json::from_str::<serde_json::Value>(arguments) {
        value.serialize(serializer)
    } else {
        arguments.serialize(serializer)
    }
}

/// Custom deserializer for the arguments field that handles both string and object formats.
fn deserialize_arguments<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;

    let value: serde_json::Value = serde::Deserialize::deserialize(deserializer)?;

    match value {
        serde_json::Value::String(s) => Ok(s),
        serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
            // If it's an object or array, serialize it back to a JSON string.
            serde_json::to_string(&value).map_err(D::Error::custom)
        }
        _ => Err(D::Error::custom("arguments must be a string or object")),
    }
}
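
// A minimal round-trip check of the two helpers above (not part of the
// original test set; the sample tool name and arguments are illustrative):
// object-form arguments from the wire become a JSON string in memory, and
// serialize back out as a raw JSON object rather than an escaped string.
#[cfg(test)]
mod argument_codec_tests {
    use super::*;

    #[test]
    fn arguments_round_trip_between_string_and_object() {
        // Object-form arguments deserialize into a JSON string...
        let wire = r#"{"name":"List","arguments":{"path":"./"}}"#;
        let call: OllamaToolCallFunction = serde_json::from_str(wire).unwrap();
        assert_eq!(call.name, "List");
        let parsed: serde_json::Value = serde_json::from_str(&call.arguments).unwrap();
        assert_eq!(parsed["path"], "./");

        // ...and serialize back out as a raw JSON object.
        let out = serde_json::to_string(&call).unwrap();
        assert!(out.contains(r#""arguments":{"path":"./"}"#));
    }
}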

#[derive(Debug, Serialize, Clone)]
pub struct OllamaTool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: OllamaFunction,
}

#[derive(Debug, Serialize, Clone)]
pub struct OllamaFunction {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

#[derive(Debug, Deserialize)]
pub struct OllamaResponse {
    pub message: OllamaMessage,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct OllamaStreamResponse {
    pub message: OllamaMessage,
    pub done: bool,
}

#[derive(Debug, Deserialize)]
pub struct OllamaModelShowResponse {
    pub details: Option<OllamaModelDetails>,
    pub model_info: Option<serde_json::Value>,
    pub template: Option<String>,
    pub system: Option<String>,
    pub parameters: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
pub struct OllamaModelDetails {
    pub format: Option<String>,
    pub family: Option<String>,
    pub families: Option<Vec<String>>,
    pub parameter_size: Option<String>,
    pub quantization_level: Option<String>,
}

// --- Client ---
pub struct OllamaClient {
    model: String,
    base_url: String,
    api_key: Option<String>,
}

impl OllamaClient {
    pub fn new(model: &str) -> Self {
        Self {
            model: model.to_string(),
            base_url: UrlConfig::get_base_url("ollama"),
            api_key: None,
        }
    }

    pub fn new_with_key(api_key: &str, model: &str) -> Self {
        tracing::debug!(
            "Ollama: Creating client with API key (length: {}), model: {}",
            api_key.len(),
            model
        );
        Self {
            model: model.to_string(),
            base_url: UrlConfig::get_base_url("ollama"),
            api_key: Some(api_key.to_string()),
        }
    }

    pub fn get_model_name(&self) -> String {
        self.model.clone()
    }

    fn add_auth_header(&self, request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        tracing::debug!(
            "Ollama: add_auth_header called, api_key present: {}",
            self.api_key.is_some()
        );

        if let Some(ref api_key) = self.api_key {
            // Detailed API key logging removed (for security and readability).
            // Add the header only when the API key is non-empty.
            if !api_key.trim().is_empty() {
                // Check whether the API key is already formatted.
                if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
                    // Already formatted (e.g. "Basic base64string" or "Bearer token").
                    // Detailed auth header logging removed (for security and readability).
                    request_builder.header("Authorization", api_key)
                } else {
                    // Choose the auth scheme based on the URL.
                    let auth_header = if self.base_url.contains("ollama.com") {
                        // ollama.com expects a Bearer token.
                        format!("Bearer {}", api_key)
                    } else {
                        // Otherwise use Basic auth (for local/proxy setups).
                        format!("Basic {}", api_key)
                    };
                    // Detailed auth header logging removed (for security and readability).
                    request_builder.header("Authorization", auth_header)
                }
            } else {
                tracing::debug!("Ollama: Empty API key, skipping auth header");
                request_builder
            }
        } else {
            tracing::debug!(
                "Ollama: No API key provided, using unauthenticated request (typical for local Ollama)"
            );
            request_builder
        }
    }

    /// Static method: fetch the model list from the Ollama server (using the default URL).
    pub async fn list_models_static(
        api_key: &str,
    ) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_models_url("ollama");

        tracing::debug!("Ollama list_models_static requesting: {}", url);

        let mut request_builder = client.get(&url);
        if !api_key.trim().is_empty() {
            // Check whether the API key is already formatted.
            if api_key.starts_with("Basic ") || api_key.starts_with("Bearer ") {
                // Already formatted.
                request_builder = request_builder.header("Authorization", api_key);
            } else {
                // Choose the auth scheme based on the URL.
                let auth_header = if url.contains("ollama.com") {
                    // ollama.com expects a Bearer token.
                    format!("Bearer {}", api_key)
                } else {
                    // Otherwise use Basic auth (for local/proxy setups).
                    format!("Basic {}", api_key)
                };
                // Detailed auth header logging removed (for security and readability).
                request_builder = request_builder.header("Authorization", auth_header);
            }
        }

        let response = request_builder.send().await.map_err(|e| {
            tracing::error!("Ollama API request failed: {}", e);
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
        })?;

        let status = response.status();
        tracing::info!("Ollama list_models_static response status: {}", status);

        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            tracing::error!(
                "Ollama list_models_static failed - Status: {}, Body: {}",
                status,
                error_body
            );
            let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
            return Err(WorkerError::from_api_error(
                error_msg,
                &crate::types::LlmProvider::Ollama,
            ));
        }

        let response_text = response.text().await.map_err(|e| {
            tracing::error!("Failed to read Ollama response text: {}", e);
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
        })?;

        // Raw response logging removed (for readability).

        let models_response: serde_json::Value =
            serde_json::from_str(&response_text).map_err(|e| {
                tracing::error!(
                    "Failed to parse Ollama JSON response: {} - Response: {}",
                    e,
                    response_text
                );
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
            })?;

        let mut models = Vec::new();

        if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
            for model in models_array {
                if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
                    models.push(crate::types::ModelInfo {
                        id: name.to_string(),
                        name: name.to_string(),
                        provider: crate::types::LlmProvider::Ollama,
                        supports_tools: true, // Will be determined by config
                        supports_function_calling: true,
                        supports_vision: false, // Will be determined by config
                        supports_multimodal: false,
                        context_length: None,
                        training_cutoff: None,
                        capabilities: vec!["text_generation".to_string()],
                        description: Some(format!("Ollama model: {}", name)),
                    });
                }
            }
        }

        tracing::info!(
            "Ollama list_models_static found {} models with metadata",
            models.len()
        );
        Ok(models)
    }

    // list_models_with_info was removed - models should be configured in models.yaml.
    // This private method is kept for future reference if needed.
    #[allow(dead_code)]
    async fn list_models_with_info_internal(
        &self,
    ) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
        let client = Client::new();
        let url = format!("{}/api/tags", self.base_url);

        tracing::debug!("Ollama list_models requesting: {}", url);

        let request = self.add_auth_header(client.get(&url));
        tracing::debug!("Ollama list_models_with_info sending request to: {}", &url);

        let response = request.send().await.map_err(|e| {
            tracing::error!("Ollama API request failed: {}", e);
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
        })?;

        let status = response.status();
        tracing::info!("Ollama list_models response status: {}", status);

        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            tracing::error!(
                "Ollama list_models failed - Status: {}, Body: {}, URL: {}",
                status,
                error_body,
                &url
            );
            let error_msg = format!("Failed to list Ollama models: {} - {}", status, error_body);
            return Err(WorkerError::from_api_error(
                error_msg,
                &crate::types::LlmProvider::Ollama,
            ));
        }

        let response_text = response.text().await.map_err(|e| {
            tracing::error!("Failed to read Ollama response text: {}", e);
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
        })?;

        // Raw response logging removed (for readability).

        let models_response: serde_json::Value =
            serde_json::from_str(&response_text).map_err(|e| {
                tracing::error!(
                    "Failed to parse Ollama JSON response: {} - Response: {}",
                    e,
                    response_text
                );
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
            })?;

        let model_names: Vec<String> = models_response
            .get("models")
            .and_then(|models| models.as_array())
            .ok_or_else(|| {
                tracing::error!("Invalid Ollama models response format - missing 'models' array");
                WorkerError::LlmApiError("Invalid models response format".to_string())
            })?
            .iter()
            .filter_map(|model| {
                model
                    .get("name")
                    .and_then(|name| name.as_str())
                    .map(|s| s.to_string())
            })
            .collect();

        // Build ModelInfo entries with default capabilities; detailed
        // capability detection is deferred to configuration.
        let mut models = Vec::new();
        for name in model_names {
            models.push(crate::types::ModelInfo {
                id: name.clone(),
                name: name.clone(),
                provider: crate::types::LlmProvider::Ollama,
                supports_tools: true, // Will be determined by config
                supports_function_calling: true,
                supports_vision: false, // Will be determined by config
                supports_multimodal: false,
                context_length: None,
                training_cutoff: None,
                capabilities: vec!["text_generation".to_string()],
                description: Some(format!("Ollama model: {}", name)),
            });
        }

        tracing::info!(
            "Ollama list_models found {} models with default capabilities",
            models.len()
        );
        Ok(models)
    }
}

use async_stream::stream;

impl OllamaClient {
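    /// Stream a chat completion from Ollama's `/api/chat` endpoint, yielding
    /// `StreamEvent`s as NDJSON lines arrive.
    ///
    /// A minimal usage sketch (hypothetical call site; assumes a reachable
    /// Ollama server, and "llama3.1" is just an illustrative model name):
    ///
    /// ```ignore
    /// let client = OllamaClient::new("llama3.1");
    /// let mut stream = client
    ///     .chat_stream(vec![Message::new(Role::User, "Hello".to_string())], None, None)
    ///     .await?;
    /// while let Some(event) = stream.next().await {
    ///     match event? {
    ///         StreamEvent::Chunk(text) => print!("{}", text),
    ///         StreamEvent::ToolCall(call) => println!("-> tool: {}", call.name),
    ///         _ => {} // Completion and any debug events
    ///     }
    /// }
    /// ```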
    pub async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[crate::types::DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebug>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        let client = Client::new();
        let url = format!("{}/api/chat", self.base_url);

        let ollama_messages: Vec<OllamaMessage> = messages
            .into_iter()
            .map(|msg| {
                // Convert tool calls if present
                let tool_calls = msg.tool_calls.map(|calls| {
                    calls
                        .into_iter()
                        .map(|call| OllamaToolCall {
                            function: OllamaToolCallFunction {
                                name: call.name,
                                arguments: call.arguments,
                            },
                        })
                        .collect()
                });

                OllamaMessage {
                    role: match msg.role {
                        Role::User => "user".to_string(),
                        Role::Model => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "tool".to_string(),
                    },
                    content: msg.content,
                    tool_calls,
                }
            })
            .collect();

        // Convert tools to Ollama format (similar to OpenAI)
        let ollama_tools = tools.map(|tools| {
            tools
                .iter()
                .map(|tool| OllamaTool {
                    tool_type: "function".to_string(),
                    function: OllamaFunction {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters_schema.clone(),
                    },
                })
                .collect()
        });

        let request = OllamaRequest {
            model: self.model.clone(),
            messages: ollama_messages,
            stream: true,
            tools: ollama_tools,
        };

        let stream = stream! {
            // Emit debug information for the outgoing request.
            if let Some(ref debug) = llm_debug {
                if let Some(debug_event) = debug.debug_request(&self.model, "Ollama", &serde_json::to_value(&request).unwrap_or_default()) {
                    yield Ok(debug_event);
                }
            }

            // Log request information.
            tracing::info!("Ollama chat_stream: Sending request to {}", &url);
            tracing::debug!("Ollama request model: {}", &request.model);
            tracing::debug!("Ollama request messages count: {}", request.messages.len());
            if let Some(ref tools) = request.tools {
                tracing::debug!("Ollama request tools count: {}", tools.len());
            }

            // Detailed request logging removed (for readability).

            let request_builder = self.add_auth_header(client.post(&url));

            let response = request_builder
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    tracing::error!("Ollama chat_stream request failed: {}", e);
                    WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
                });

            let response = match response {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            if !response.status().is_success() {
                let status = response.status();
                let error_body = response.text().await.unwrap_or_default();
                tracing::error!("Ollama chat_stream failed - Status: {}, Body: {}, URL: {}", status, error_body, &url);
                yield Err(WorkerError::from_api_error(
                    format!("Ollama API error: {} - {}", status, error_body),
                    &crate::types::LlmProvider::Ollama,
                ));
                return;
            } else {
                tracing::info!("Ollama chat_stream response status: {}", response.status());
            }

            let mut byte_stream = response.bytes_stream();
            let mut buffer = String::new();
            let mut full_content = String::new();
            let mut chunk_count = 0;

            tracing::debug!("Ollama chat_stream: Starting to process response stream");

            while let Some(chunk) = byte_stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        chunk_count += 1;
                        let chunk_str = String::from_utf8_lossy(&bytes);
                        // Detailed chunk logging removed (for readability).
                        buffer.push_str(&chunk_str);

                        // Process line by line
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line.trim().is_empty() {
                                continue;
                            }

                            // Detailed stream-line logging removed (for readability).

                            match serde_json::from_str::<OllamaStreamResponse>(&line) {
                                Ok(stream_response) => {
                                    // Emit debug information for the parsed response.
                                    if let Some(ref debug) = llm_debug {
                                        if let Some(debug_event) = debug.debug_response(&self.model, "Ollama", &serde_json::to_value(&stream_response).unwrap_or_default()) {
                                            yield Ok(debug_event);
                                        }
                                    }

                                    // Handle tool calls
                                    if let Some(tool_calls) = &stream_response.message.tool_calls {
                                        tracing::info!("Ollama stream response contains {} tool calls", tool_calls.len());
                                        for (i, tool_call) in tool_calls.iter().enumerate() {
                                            tracing::debug!("Tool call #{}: name={}, arguments={}",
                                                i + 1, tool_call.function.name, tool_call.function.arguments);
                                            let parsed_tool_call = ToolCall {
                                                name: tool_call.function.name.clone(),
                                                arguments: tool_call.function.arguments.clone(),
                                            };
                                            yield Ok(StreamEvent::ToolCall(parsed_tool_call));
                                        }
                                    }

                                    // Handle regular content
                                    if !stream_response.message.content.is_empty() {
                                        full_content.push_str(&stream_response.message.content);
                                        yield Ok(StreamEvent::Chunk(stream_response.message.content));
                                    }

                                    if stream_response.done {
                                        tracing::info!("Ollama stream completed, total content: {} chars", full_content.len());
                                        tracing::debug!("Ollama complete response content: {}", full_content);
                                        yield Ok(StreamEvent::Completion(Message::new(
                                            Role::Model,
                                            full_content.clone(),
                                        )));
                                        break;
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!("Failed to parse Ollama stream response: {} - Line: {}", e, line);
                                    tracing::debug!("Parse error details: line_length={}, error={}", line.len(), e);
                                }
                            }
                        }
                    }
                    Err(e) => {
                        tracing::error!("Ollama stream error after {} chunks: {}", chunk_count, e);
                        yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama));
                        break;
                    }
                }
            }

            tracing::debug!("Ollama chat_stream: Stream ended, processed {} chunks", chunk_count);
        };

        Ok(Box::new(Box::pin(stream)))
    }

    pub async fn get_model_details(
        &self,
        model_name: &str,
    ) -> Result<crate::types::ModelInfo, WorkerError> {
        let client = Client::new();
        let url = format!("{}/api/show", self.base_url);

        let request = serde_json::json!({
            "name": model_name
        });

        let response = self
            .add_auth_header(client.post(&url))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
            })?;

        if !response.status().is_success() {
            return Err(WorkerError::from_api_error(
                format!(
                    "Ollama model details request failed with status: {}",
                    response.status()
                ),
                &crate::types::LlmProvider::Ollama,
            ));
        }

        let model_data: serde_json::Value = response.json().await.map_err(|e| {
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Ollama)
        })?;

        let details = model_data
            .get("details")
            .unwrap_or(&serde_json::Value::Null);
        let family = details
            .get("family")
            .and_then(|f| f.as_str())
            .unwrap_or("unknown");
        let parameter_size = details
            .get("parameter_size")
            .and_then(|p| p.as_str())
            .unwrap_or("unknown");
        let quantization = details
            .get("quantization_level")
            .and_then(|q| q.as_str())
            .unwrap_or("unknown");

        let size = model_data.get("size").and_then(|s| s.as_u64()).unwrap_or(0);

        let modified_at = model_data
            .get("modified_at")
            .and_then(|m| m.as_str())
            .map(|s| s.to_string());

        let supports_tools = true; // Will be determined by config
        let context_length = None; // Will be determined by config
        let capabilities = vec!["text_generation".to_string()]; // Basic default
        let description = format!("Ollama model: {}", model_name);

        Ok(crate::types::ModelInfo {
            id: model_name.to_string(),
            name: format!("{} ({}, {})", model_name, family, parameter_size),
            provider: crate::types::LlmProvider::Ollama,
            supports_tools,
            supports_function_calling: supports_tools,
            supports_vision: false, // Will be determined dynamically
            supports_multimodal: false,
            context_length,
            training_cutoff: modified_at,
            capabilities,
            description: Some(format!(
                "{} (Size: {} bytes, Quantization: {})",
                description, size, quantization
            )),
        })
    }

    pub async fn check_connection(&self) -> Result<(), WorkerError> {
        let client = Client::new();
        let url = format!("{}/api/tags", self.base_url);
        self.add_auth_header(client.get(&url))
            .send()
            .await
            .map_err(|e| WorkerError::LlmApiError(format!("Failed to connect to Ollama: {}", e)))?;
        Ok(())
    }
}

#[async_trait::async_trait]
impl LlmClientTrait for OllamaClient {
    async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebug>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        self.chat_stream(messages, tools, llm_debug).await
    }

    async fn check_connection(&self) -> Result<(), WorkerError> {
        self.check_connection().await
    }

    fn provider(&self) -> LlmProvider {
        LlmProvider::Ollama
    }

    fn get_model_name(&self) -> String {
        self.get_model_name()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Message, Role, ToolCall};
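
    // Supplementary check, not in the original suite: add_auth_header picks
    // Bearer auth for ollama.com URLs and Basic auth otherwise. The URLs and
    // token value are illustrative; no request is actually sent.
    #[test]
    fn test_auth_scheme_selection_by_url() {
        let http = Client::new();

        let hosted = OllamaClient {
            model: "test".to_string(),
            base_url: "https://ollama.com".to_string(),
            api_key: Some("token123".to_string()),
        };
        let request = hosted
            .add_auth_header(http.get("https://ollama.com/api/tags"))
            .build()
            .unwrap();
        let auth = request.headers().get("Authorization").unwrap();
        assert_eq!(auth.to_str().unwrap(), "Bearer token123");

        let local = OllamaClient {
            model: "test".to_string(),
            base_url: "http://localhost:11434".to_string(),
            api_key: Some("token123".to_string()),
        };
        let request = local
            .add_auth_header(http.get("http://localhost:11434/api/tags"))
            .build()
            .unwrap();
        let auth = request.headers().get("Authorization").unwrap();
        assert_eq!(auth.to_str().unwrap(), "Basic token123");
    }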

    #[test]
    fn test_message_conversion_with_tool_calls() {
        let tool_call = ToolCall {
            name: "List".to_string(),
            arguments: r#"{"path": "./"}"#.to_string(),
        };

        let message = Message::with_tool_calls(
            Role::Model,
            "".to_string(), // Empty content, only tool calls
            vec![tool_call.clone()],
        );

        let messages = vec![message];

        // Simulate the conversion that happens in chat_stream
        let ollama_messages: Vec<OllamaMessage> = messages
            .into_iter()
            .map(|msg| {
                // Convert tool calls if present
                let tool_calls = msg.tool_calls.map(|calls| {
                    calls
                        .into_iter()
                        .map(|call| OllamaToolCall {
                            function: OllamaToolCallFunction {
                                name: call.name,
                                arguments: call.arguments,
                            },
                        })
                        .collect()
                });

                OllamaMessage {
                    role: "assistant".to_string(),
                    content: msg.content,
                    tool_calls,
                }
            })
            .collect();

        // Verify the conversion preserved tool calls
        assert_eq!(ollama_messages.len(), 1);
        let converted_msg = &ollama_messages[0];
        assert_eq!(converted_msg.role, "assistant");
        assert_eq!(converted_msg.content, "");
        assert!(converted_msg.tool_calls.is_some());

        let converted_tool_calls = converted_msg.tool_calls.as_ref().unwrap();
        assert_eq!(converted_tool_calls.len(), 1);
        assert_eq!(converted_tool_calls[0].function.name, "List");
        assert_eq!(
            converted_tool_calls[0].function.arguments,
            r#"{"path": "./"}"#
        );
    }
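
    // A small additional check (not in the original test set): OllamaRequest
    // uses `skip_serializing_if = "Option::is_none"`, so a request without
    // tools should omit the "tools" key entirely rather than send null.
    #[test]
    fn test_request_omits_tools_when_none() {
        let request = OllamaRequest {
            model: "test-model".to_string(),
            messages: vec![],
            stream: true,
            tools: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("tools"));
        assert!(json.contains(r#""stream":true"#));
    }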

    #[test]
    fn test_message_conversion_without_tool_calls() {
        let message = Message::new(Role::User, "Hello".to_string());
        let messages = vec![message];

        let ollama_messages: Vec<OllamaMessage> = messages
            .into_iter()
            .map(|msg| {
                let tool_calls = msg.tool_calls.map(|calls| {
                    calls
                        .into_iter()
                        .map(|call| OllamaToolCall {
                            function: OllamaToolCallFunction {
                                name: call.name,
                                arguments: call.arguments,
                            },
                        })
                        .collect()
                });

                OllamaMessage {
                    role: "user".to_string(),
                    content: msg.content,
                    tool_calls,
                }
            })
            .collect();

        assert_eq!(ollama_messages.len(), 1);
        let converted_msg = &ollama_messages[0];
        assert_eq!(converted_msg.role, "user");
        assert_eq!(converted_msg.content, "Hello");
        assert!(converted_msg.tool_calls.is_none());
    }
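
    // Supplementary check (not in the original suite): chat_stream parses each
    // NDJSON line into an OllamaStreamResponse; this mirrors that per-line
    // parse on a hand-written sample shaped like Ollama's /api/chat output.
    #[test]
    fn test_stream_response_line_parses() {
        let line = r#"{"message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let parsed: OllamaStreamResponse = serde_json::from_str(line).unwrap();
        assert_eq!(parsed.message.role, "assistant");
        assert_eq!(parsed.message.content, "Hi");
        assert!(!parsed.done);
        assert!(parsed.message.tool_calls.is_none());
    }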
}