diff --git a/worker-macros/src/lib.rs b/worker-macros/src/lib.rs index c580437..46c31a1 100644 --- a/worker-macros/src/lib.rs +++ b/worker-macros/src/lib.rs @@ -6,7 +6,7 @@ use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, + Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input, }; /// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。 @@ -311,7 +311,7 @@ pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream { } /// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。 -/// +/// /// # Example /// ```ignore /// #[tool] diff --git a/worker-types/src/hook.rs b/worker-types/src/hook.rs index 7c3dd35..d658cca 100644 --- a/worker-types/src/hook.rs +++ b/worker-types/src/hook.rs @@ -127,7 +127,10 @@ pub trait WorkerHook: Send + Sync { /// ツール実行後 /// /// 結果を書き換えたり、隠蔽したりできる。 - async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result { + async fn after_tool_call( + &self, + _tool_result: &mut ToolResult, + ) -> Result { Ok(ControlFlow::Continue) } diff --git a/worker-types/src/message.rs b/worker-types/src/message.rs index 6981842..dc66909 100644 --- a/worker-types/src/message.rs +++ b/worker-types/src/message.rs @@ -54,7 +54,10 @@ pub enum ContentPart { }, /// ツール結果 #[serde(rename = "tool_result")] - ToolResult { tool_use_id: String, content: String }, + ToolResult { + tool_use_id: String, + content: String, + }, } impl Message { diff --git a/worker-types/src/subscriber.rs b/worker-types/src/subscriber.rs index 7d87c86..ac62ef1 100644 --- a/worker-types/src/subscriber.rs +++ b/worker-types/src/subscriber.rs @@ -3,9 +3,7 @@ //! Timeline層のHandler機構の薄いラッパーとして設計され、 //! UIへのストリーミング表示やリアルタイムフィードバックを可能にする。 -use crate::{ - ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent, -}; +use crate::{ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent}; // ============================================================================= // WorkerSubscriber Trait @@ -74,7 +72,11 @@ pub trait WorkerSubscriber: Send { /// /// Start/InputJsonDelta/Stopのライフサイクルを持つ。 #[allow(unused_variables)] - fn on_tool_use_block(&mut self, scope: &mut Self::ToolUseBlockScope, event: &ToolUseBlockEvent) { + fn on_tool_use_block( + &mut self, + scope: &mut Self::ToolUseBlockScope, + event: &ToolUseBlockEvent, + ) { } // ========================================================================= diff --git a/worker/examples/llm_client_gemini.rs b/worker/examples/llm_client_gemini.rs index d7e3f50..3c8fbe9 100644 --- a/worker/examples/llm_client_gemini.rs +++ b/worker/examples/llm_client_gemini.rs @@ -111,8 +111,8 @@ impl Handler for UsageTracker { #[tokio::main] async fn main() -> Result<(), Box> { // APIキーを環境変数から取得 - let api_key = std::env::var("GEMINI_API_KEY") - .expect("GEMINI_API_KEY environment variable must be set"); + let api_key = + std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set"); println!("=== Gemini LLM Client + Timeline Integration Example ===\n"); diff --git a/worker/examples/record_test_fixtures/main.rs b/worker/examples/record_test_fixtures/main.rs index 90aaf82..a24acec 100644 --- a/worker/examples/record_test_fixtures/main.rs +++ b/worker/examples/record_test_fixtures/main.rs @@ -16,9 +16,6 @@ //! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all //! ``` - - - mod recorder; mod scenarios; @@ -82,7 +79,8 @@ async fn run_scenario_with_openai( subdir: &str, model: Option, ) -> Result<(), Box> { - let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set"); + let api_key = + std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set"); let model = model.as_deref().unwrap_or("gpt-4o"); let client = OpenAIClient::new(&api_key, model); @@ -125,8 +123,8 @@ async fn run_scenario_with_gemini( subdir: &str, model: Option, ) -> Result<(), Box> { - let api_key = std::env::var("GEMINI_API_KEY") - .expect("GEMINI_API_KEY environment variable must be set"); + let api_key = + std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set"); let model = model.as_deref().unwrap_or("gemini-2.0-flash"); let client = GeminiClient::new(&api_key, model); @@ -142,9 +140,6 @@ async fn run_scenario_with_gemini( Ok(()) } - - - #[tokio::main] async fn main() -> Result<(), Box> { dotenv::dotenv().ok(); @@ -173,13 +168,13 @@ async fn main() -> Result<(), Box> { .collect(); if found.is_empty() { - eprintln!("Error: Unknown scenario '{}'", scenario_name); - // Verify correct name by listing - println!("Available scenarios:"); - for s in scenarios::scenarios() { - println!(" {}", s.output_name); - } - std::process::exit(1); + eprintln!("Error: Unknown scenario '{}'", scenario_name); + // Verify correct name by listing + println!("Available scenarios:"); + for s in scenarios::scenarios() { + println!(" {}", s.output_name); + } + std::process::exit(1); } found }; @@ -201,12 +196,20 @@ async fn main() -> Result<(), Box> { // シナリオのフィルタリングは main.rs のロジックで実行済み // ここでは単純なループで実行 for scenario in scenarios_to_run { - match args.client { - ClientType::Anthropic => run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await?, - ClientType::Gemini => run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await?, - ClientType::Openai => run_scenario_with_openai(&scenario, subdir, args.model.clone()).await?, - ClientType::Ollama => run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await?, - } + match args.client { + ClientType::Anthropic => { + run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await? + } + ClientType::Gemini => { + run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await? + } + ClientType::Openai => { + run_scenario_with_openai(&scenario, subdir, args.model.clone()).await? + } + ClientType::Ollama => { + run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await? + } + } } println!("\n✅ Done!"); diff --git a/worker/examples/worker_cli.rs b/worker/examples/worker_cli.rs index 6faa318..8b6aea7 100644 --- a/worker/examples/worker_cli.rs +++ b/worker/examples/worker_cli.rs @@ -38,14 +38,14 @@ use tracing_subscriber::EnvFilter; use clap::{Parser, ValueEnum}; use worker::{ + Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker, llm_client::{ + LlmClient, providers::{ anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient, openai::OpenAIClient, }, - LlmClient, }, - Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker, }; use worker_macros::tool_registry; use worker_types::Message; @@ -310,9 +310,8 @@ async fn main() -> Result<(), Box> { // ロギング初期化 // RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示 // デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能 - let filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("warn")); - + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn")); + tracing_subscriber::fmt() .with_env_filter(filter) .with_target(true) @@ -320,7 +319,7 @@ async fn main() -> Result<(), Box> { // CLI引数をパース let args = Args::parse(); - + info!( provider = ?args.provider, model = ?args.model, diff --git a/worker/src/llm_client/providers/anthropic.rs b/worker/src/llm_client/providers/anthropic.rs index 70fd74e..5090564 100644 --- a/worker/src/llm_client/providers/anthropic.rs +++ b/worker/src/llm_client/providers/anthropic.rs @@ -6,7 +6,7 @@ use std::pin::Pin; use async_trait::async_trait; use eventsource_stream::Eventsource; -use futures::{future::ready, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt, future::ready}; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; use worker_types::Event; @@ -178,7 +178,6 @@ impl LlmClient for AnthropicClient { } } - #[cfg(test)] mod tests { use super::*; diff --git a/worker/src/llm_client/providers/ollama.rs b/worker/src/llm_client/providers/ollama.rs index f889539..e813f7e 100644 --- a/worker/src/llm_client/providers/ollama.rs +++ b/worker/src/llm_client/providers/ollama.rs @@ -10,9 +10,7 @@ use futures::Stream; use worker_types::Event; use crate::llm_client::{ - ClientError, LlmClient, Request, - providers::openai::OpenAIClient, - scheme::openai::OpenAIScheme, + ClientError, LlmClient, Request, providers::openai::OpenAIClient, scheme::openai::OpenAIScheme, }; /// Ollama クライアント @@ -29,7 +27,7 @@ impl OllamaClient { // Ollama usually runs on localhost:11434/v1 // API key is "ollama" or ignored let base_url = "http://localhost:11434"; - + let scheme = OpenAIScheme::new().with_legacy_max_tokens(true); let client = OpenAIClient::new("ollama", model) @@ -37,7 +35,7 @@ impl OllamaClient { .with_scheme(scheme); // Currently OpenAIScheme sets include_usage: true. Ollama supports checks? // Assuming Ollama modern versions support usage. - + Self { inner: client } } @@ -46,7 +44,7 @@ impl OllamaClient { self.inner = self.inner.with_base_url(url); self } - + /// カスタムHTTPクライアントを設定 pub fn with_http_client(mut self, client: reqwest::Client) -> Self { self.inner = self.inner.with_http_client(client); diff --git a/worker/src/llm_client/providers/openai.rs b/worker/src/llm_client/providers/openai.rs index 7a9a576..6da17e1 100644 --- a/worker/src/llm_client/providers/openai.rs +++ b/worker/src/llm_client/providers/openai.rs @@ -61,21 +61,21 @@ impl OpenAIClient { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - + let api_key_val = if self.api_key.is_empty() { // For providers like Ollama, API key might be empty/dummy. // But typical OpenAI requires it. // We'll allow empty if user intends it, but usually it's checked. - HeaderValue::from_static("") + HeaderValue::from_static("") } else { - let mut val = HeaderValue::from_str(&format!("Bearer {}", self.api_key)) + let mut val = HeaderValue::from_str(&format!("Bearer {}", self.api_key)) .map_err(|e| ClientError::Config(format!("Invalid API key: {}", e)))?; - val.set_sensitive(true); - val + val.set_sensitive(true); + val }; if !api_key_val.is_empty() { - headers.insert("Authorization", api_key_val); + headers.insert("Authorization", api_key_val); } Ok(headers) @@ -92,24 +92,24 @@ impl LlmClient for OpenAIClient { // Standard OpenAI base is "https://api.openai.com". Endpoint is "/v1/chat/completions". // If external base_url includes /v1, we should be careful. // Let's assume defaults. If user provides "http://localhost:11434/v1", we append "/chat/completions". - // Or cleaner: user provides full base up to version? - // Anthropic client uses "{}/v1/messages". - // Let's stick to appending "/v1/chat/completions" if base is just host, + // Or cleaner: user provides full base up to version? + // Anthropic client uses "{}/v1/messages". + // Let's stick to appending "/v1/chat/completions" if base is just host, // OR assume base includes /v1 if user overrides it? // Let's use robust joining or simple assumption matching Anthropic pattern: // Default: https://api.openai.com -> https://api.openai.com/v1/chat/completions - + // However, Ollama default is http://localhost:11434/v1/chat/completions if using OpenAI compact. // If we configure base_url via `with_base_url`, it's flexible. // Let's try to detect if /v1 is present or just append consistently. // Ideally `base_url` should be the root passed to `new`. - + let url = if self.base_url.ends_with("/v1") { - format!("{}/chat/completions", self.base_url) + format!("{}/chat/completions", self.base_url) } else if self.base_url.ends_with("/") { - format!("{}v1/chat/completions", self.base_url) + format!("{}v1/chat/completions", self.base_url) } else { - format!("{}/v1/chat/completions", self.base_url) + format!("{}/v1/chat/completions", self.base_url) }; let headers = self.build_headers()?; @@ -159,40 +159,41 @@ impl LlmClient for OpenAIClient { .map_err(|e| std::io::Error::other(e)); let event_stream = byte_stream.eventsource(); - let stream = event_stream.map(move |result| { - match result { - Ok(event) => { - // SSEイベントをパース - // OpenAI stream events are "data: {...}" - // event.event is usually "message" (default) or empty. - // parse_event takes data string. - - if event.data == "[DONE]" { - // End of stream handled inside parse_event usually returning None - Ok(None) - } else { - match scheme.parse_event(&event.data) { - Ok(Some(events)) => Ok(Some(events)), - Ok(None) => Ok(None), - Err(e) => Err(e), + let stream = event_stream + .map(move |result| { + match result { + Ok(event) => { + // SSEイベントをパース + // OpenAI stream events are "data: {...}" + // event.event is usually "message" (default) or empty. + // parse_event takes data string. + + if event.data == "[DONE]" { + // End of stream handled inside parse_event usually returning None + Ok(None) + } else { + match scheme.parse_event(&event.data) { + Ok(Some(events)) => Ok(Some(events)), + Ok(None) => Ok(None), + Err(e) => Err(e), + } } } + Err(e) => Err(ClientError::Sse(e.to_string())), } - Err(e) => Err(ClientError::Sse(e.to_string())), - } - }) - // flatten Option> stream to Stream - // map returns Result>, Error> - // We want Stream> - .map(|res| { - let s: Pin> + Send>> = match res { - Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))), - Ok(None) => Box::pin(futures::stream::empty()), - Err(e) => Box::pin(futures::stream::once(async move { Err(e) })), - }; - s - }) - .flatten(); + }) + // flatten Option> stream to Stream + // map returns Result>, Error> + // We want Stream> + .map(|res| { + let s: Pin> + Send>> = match res { + Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))), + Ok(None) => Box::pin(futures::stream::empty()), + Err(e) => Box::pin(futures::stream::once(async move { Err(e) })), + }; + s + }) + .flatten(); Ok(Box::pin(stream)) } diff --git a/worker/src/llm_client/scheme/gemini/events.rs b/worker/src/llm_client/scheme/gemini/events.rs index ef7d9f0..0fd1fb7 100644 --- a/worker/src/llm_client/scheme/gemini/events.rs +++ b/worker/src/llm_client/scheme/gemini/events.rs @@ -127,13 +127,12 @@ impl GeminiScheme { return Ok(None); } - let response: GenerateContentResponse = serde_json::from_str(data).map_err(|e| { - ClientError::Api { + let response: GenerateContentResponse = + serde_json::from_str(data).map_err(|e| ClientError::Api { status: None, code: Some("parse_error".to_string()), message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data), - } - })?; + })?; let mut events = Vec::new(); @@ -155,10 +154,7 @@ impl GeminiScheme { if !text.is_empty() { // Geminiは明示的なBlockStartを送らないため、 // TextDeltaを直接送る(Timelineが暗黙的に開始を処理) - events.push(Event::text_delta( - part_index, - text.clone(), - )); + events.push(Event::text_delta(part_index, text.clone())); } } @@ -167,10 +163,10 @@ impl GeminiScheme { // 関数呼び出しの開始 // Geminiでは関数呼び出しは一度に送られることが多い // ストリーミング引数が有効な場合は部分的に送られる可能性がある - + // 関数呼び出しIDはGeminiにはないので、名前をIDとして使用 let function_id = format!("call_{}", function_call.name); - + events.push(Event::BlockStart(BlockStart { index: candidate_index * 10 + part_index, // 複合インデックス block_type: BlockType::ToolUse, @@ -240,7 +236,8 @@ mod tests { #[test] fn test_parse_text_response() { let scheme = GeminiScheme::new(); - let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#; + let data = + r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#; let events = scheme.parse_event(data).unwrap().unwrap(); assert_eq!(events.len(), 1); @@ -263,7 +260,7 @@ mod tests { let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}"#; let events = scheme.parse_event(data).unwrap().unwrap(); - + // Usageイベントが含まれるはず let usage_event = events.iter().find(|e| matches!(e, Event::Usage(_))); assert!(usage_event.is_some()); @@ -281,7 +278,7 @@ mod tests { let data = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"location":"Tokyo"}}}],"role":"model"},"index":0}]}"#; let events = scheme.parse_event(data).unwrap().unwrap(); - + // BlockStartイベントがあるはず let start_event = events.iter().find(|e| matches!(e, Event::BlockStart(_))); assert!(start_event.is_some()); @@ -312,7 +309,7 @@ mod tests { let data = r#"{"candidates":[{"content":{"parts":[{"text":"Done"}],"role":"model"},"finishReason":"STOP","index":0}]}"#; let events = scheme.parse_event(data).unwrap().unwrap(); - + // BlockStopイベントがあるはず let stop_event = events.iter().find(|e| matches!(e, Event::BlockStop(_))); assert!(stop_event.is_some()); diff --git a/worker/src/llm_client/scheme/gemini/request.rs b/worker/src/llm_client/scheme/gemini/request.rs index 1c2b0ed..6785ea8 100644 --- a/worker/src/llm_client/scheme/gemini/request.rs +++ b/worker/src/llm_client/scheme/gemini/request.rs @@ -46,9 +46,7 @@ pub(crate) struct GeminiContent { #[serde(untagged)] pub(crate) enum GeminiPart { /// テキストパーツ - Text { - text: String, - }, + Text { text: String }, /// 関数呼び出しパーツ FunctionCall { #[serde(rename = "functionCall")] @@ -160,11 +158,7 @@ impl GeminiScheme { vec![] } else { vec![GeminiTool { - function_declarations: request - .tools - .iter() - .map(|t| self.convert_tool(t)) - .collect(), + function_declarations: request.tools.iter().map(|t| self.convert_tool(t)).collect(), }] }; @@ -224,34 +218,30 @@ impl GeminiScheme { }, }] } - MessageContent::Parts(parts) => { - parts - .iter() - .map(|p| match p { - ContentPart::Text { text } => GeminiPart::Text { text: text.clone() }, - ContentPart::ToolUse { id: _, name, input } => { - GeminiPart::FunctionCall { - function_call: GeminiFunctionCall { - name: name.clone(), - args: input.clone(), - }, - } - } - ContentPart::ToolResult { - tool_use_id, - content, - } => GeminiPart::FunctionResponse { - function_response: GeminiFunctionResponse { + MessageContent::Parts(parts) => parts + .iter() + .map(|p| match p { + ContentPart::Text { text } => GeminiPart::Text { text: text.clone() }, + ContentPart::ToolUse { id: _, name, input } => GeminiPart::FunctionCall { + function_call: GeminiFunctionCall { + name: name.clone(), + args: input.clone(), + }, + }, + ContentPart::ToolResult { + tool_use_id, + content, + } => GeminiPart::FunctionResponse { + function_response: GeminiFunctionResponse { + name: tool_use_id.clone(), + response: GeminiFunctionResponseContent { name: tool_use_id.clone(), - response: GeminiFunctionResponseContent { - name: tool_use_id.clone(), - content: serde_json::Value::String(content.clone()), - }, + content: serde_json::Value::String(content.clone()), }, }, - }) - .collect() - } + }, + }) + .collect(), }; GeminiContent { @@ -306,16 +296,17 @@ mod tests { assert_eq!(gemini_req.tools.len(), 1); assert_eq!(gemini_req.tools[0].function_declarations.len(), 1); - assert_eq!(gemini_req.tools[0].function_declarations[0].name, "get_weather"); + assert_eq!( + gemini_req.tools[0].function_declarations[0].name, + "get_weather" + ); assert!(gemini_req.tool_config.is_some()); } #[test] fn test_assistant_role_is_model() { let scheme = GeminiScheme::new(); - let request = Request::new() - .user("Hello") - .assistant("Hi there!"); + let request = Request::new().user("Hello").assistant("Hi there!"); let gemini_req = scheme.build_request(&request); diff --git a/worker/src/llm_client/scheme/openai/events.rs b/worker/src/llm_client/scheme/openai/events.rs index 7df1c37..b7e0eb8 100644 --- a/worker/src/llm_client/scheme/openai/events.rs +++ b/worker/src/llm_client/scheme/openai/events.rs @@ -69,8 +69,8 @@ impl OpenAIScheme { return Ok(None); } - let chunk: ChatCompletionChunk = serde_json::from_str(data) - .map_err(|e| ClientError::Api { + let chunk: ChatCompletionChunk = + serde_json::from_str(data).map_err(|e| ClientError::Api { status: None, code: Some("parse_error".to_string()), message: format!("Failed to parse SSE data: {} -> {}", e, data), @@ -102,10 +102,14 @@ impl OpenAIScheme { for tool_call in tool_calls { // Start of tool call (has ID) if let Some(id) = tool_call.id { - let name = tool_call.function.as_ref().and_then(|f| f.name.clone()).unwrap_or_default(); + let name = tool_call + .function + .as_ref() + .and_then(|f| f.name.clone()) + .unwrap_or_default(); events.push(Event::tool_use_start(tool_call.index, id, name)); } - + // Arguments delta if let Some(function) = tool_call.function { if let Some(args) = function.arguments { @@ -116,7 +120,7 @@ impl OpenAIScheme { } } } - + // Finish Reason if let Some(finish_reason) = choice.finish_reason { let stop_reason = match finish_reason.as_str() { @@ -125,9 +129,10 @@ impl OpenAIScheme { "tool_calls" | "function_call" => Some(StopReason::ToolUse), _ => Some(StopReason::EndTurn), }; - - let is_tool_finish = finish_reason == "tool_calls" || finish_reason == "function_call"; - + + let is_tool_finish = + finish_reason == "tool_calls" || finish_reason == "function_call"; + if is_tool_finish { // ツール呼び出し終了 // Note: OpenAIはどのツールが終了したか明示しないため、 @@ -156,11 +161,11 @@ mod tests { fn test_parse_text_delta() { let scheme = OpenAIScheme::new(); let data = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#; - + let events = scheme.parse_event(data).unwrap().unwrap(); // OpenAIはBlockStartを発行しないため、デルタのみ assert_eq!(events.len(), 1); - + if let Event::BlockDelta(delta) = &events[0] { assert_eq!(delta.index, 0); if let DeltaContent::Text(text) = &delta.delta { @@ -178,9 +183,9 @@ mod tests { let scheme = OpenAIScheme::new(); // Start of tool call let data_start = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}"#; - + let events = scheme.parse_event(data_start).unwrap().unwrap(); - assert_eq!(events.len(), 1); + assert_eq!(events.len(), 1); if let Event::BlockStart(start) = &events[0] { assert_eq!(start.index, 0); if let worker_types::BlockMetadata::ToolUse { id, name } = &start.metadata { diff --git a/worker/src/llm_client/scheme/openai/request.rs b/worker/src/llm_client/scheme/openai/request.rs index 2bb58ae..9251094 100644 --- a/worker/src/llm_client/scheme/openai/request.rs +++ b/worker/src/llm_client/scheme/openai/request.rs @@ -120,12 +120,7 @@ impl OpenAIScheme { }); } - messages.extend( - request - .messages - .iter() - .map(|m| self.convert_message(m)) - ); + messages.extend(request.messages.iter().map(|m| self.convert_message(m))); let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect(); @@ -143,7 +138,9 @@ impl OpenAIScheme { top_p: request.config.top_p, stop: request.config.stop_sequences.clone(), stream: true, - stream_options: Some(StreamOptions { include_usage: true }), + stream_options: Some(StreamOptions { + include_usage: true, + }), messages, tools, tool_choice: None, // Default to auto if tools are present? Or let API decide (which is auto) @@ -224,14 +221,14 @@ impl OpenAIScheme { name: None, } } else { - let content = if content_parts.is_empty() { + let content = if content_parts.is_empty() { None } else if content_parts.len() == 1 { // Simplify single text part to just Text content if preferred, or keep as Parts if let OpenAIContentPart::Text { text } = &content_parts[0] { - Some(OpenAIContent::Text(text.clone())) + Some(OpenAIContent::Text(text.clone())) } else { - Some(OpenAIContent::Parts(content_parts)) + Some(OpenAIContent::Parts(content_parts)) } } else { Some(OpenAIContent::Parts(content_parts)) @@ -265,13 +262,10 @@ impl OpenAIScheme { mod tests { use super::*; - #[test] fn test_build_simple_request() { let scheme = OpenAIScheme::new(); - let request = Request::new() - .system("System prompt") - .user("Hello"); + let request = Request::new().system("System prompt").user("Hello"); let body = scheme.build_request("gpt-4o", &request); @@ -279,7 +273,7 @@ mod tests { assert_eq!(body.messages.len(), 2); assert_eq!(body.messages[0].role, "system"); assert_eq!(body.messages[1].role, "user"); - + // Check system content if let Some(OpenAIContent::Text(text)) = &body.messages[0].content { assert_eq!(text, "System prompt"); @@ -303,12 +297,10 @@ mod tests { #[test] fn test_build_request_legacy_max_tokens() { let scheme = OpenAIScheme::new().with_legacy_max_tokens(true); - let request = Request::new() - .user("Hello") - .max_tokens(100); + let request = Request::new().user("Hello").max_tokens(100); let body = scheme.build_request("llama3", &request); - + // max_tokens should be set, max_completion_tokens should be None assert_eq!(body.max_tokens, Some(100)); assert!(body.max_completion_tokens.is_none()); @@ -317,12 +309,10 @@ mod tests { #[test] fn test_build_request_modern_max_tokens() { let scheme = OpenAIScheme::new(); // Default matches modern (legacy=false) - let request = Request::new() - .user("Hello") - .max_tokens(100); + let request = Request::new().user("Hello").max_tokens(100); let body = scheme.build_request("gpt-4o", &request); - + // max_completion_tokens should be set, max_tokens should be None assert_eq!(body.max_completion_tokens, Some(100)); assert!(body.max_tokens.is_none()); diff --git a/worker/src/timeline.rs b/worker/src/timeline.rs index 87cc79f..72afa07 100644 --- a/worker/src/timeline.rs +++ b/worker/src/timeline.rs @@ -502,13 +502,13 @@ impl Timeline { fn handle_block_delta(&mut self, delta: &BlockDelta) { let block_type = delta.delta.block_type(); - + // OpenAIなどのプロバイダはBlockStartを送らない場合があるため、 // Deltaが来たときにスコープがなければ暗黙的に開始する if self.current_block.is_none() { self.current_block = Some(block_type); } - + let handlers = self.get_block_handlers_mut(block_type); for handler in handlers { // スコープがなければ暗黙的に開始 diff --git a/worker/src/worker.rs b/worker/src/worker.rs index 7ba3ada..38ba751 100644 --- a/worker/src/worker.rs +++ b/worker/src/worker.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex}; use futures::StreamExt; use tracing::{debug, info, trace, warn}; +use crate::Timeline; use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition}; use crate::subscriber_adapter::{ ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter, @@ -11,7 +12,6 @@ use crate::subscriber_adapter::{ }; use crate::text_block_collector::TextBlockCollector; use crate::tool_call_collector::ToolCallCollector; -use crate::Timeline; use worker_types::{ ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError, ToolResult, TurnResult, WorkerHook, WorkerSubscriber, @@ -223,7 +223,7 @@ impl Worker { pub async fn run(&mut self, messages: Vec) -> Result, WorkerError> { let mut context = messages; let tool_definitions = self.build_tool_definitions(); - + info!( message_count = context.len(), tool_count = tool_definitions.len(), @@ -442,10 +442,7 @@ impl Worker { } /// Hooks: on_turn_end - async fn run_on_turn_end_hooks( - &self, - messages: &[Message], - ) -> Result { + async fn run_on_turn_end_hooks(&self, messages: &[Message]) -> Result { for hook in &self.hooks { let result = hook.on_turn_end(messages).await?; match result { diff --git a/worker/tests/common/mod.rs b/worker/tests/common/mod.rs index d701bc8..018f54a 100644 --- a/worker/tests/common/mod.rs +++ b/worker/tests/common/mod.rs @@ -3,13 +3,13 @@ use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::{Path, PathBuf}; -use std::sync::{Arc, Mutex}; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use async_trait::async_trait; use futures::Stream; -use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline}; use worker::llm_client::{ClientError, LlmClient, Request}; +use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline}; use worker_types::{BlockType, DeltaContent, Event}; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -51,11 +51,11 @@ impl LlmClient for MockLlmClient { ) -> Result> + Send>>, ClientError> { let count = self.call_count.fetch_add(1, Ordering::SeqCst); if count >= self.responses.len() { - return Err(ClientError::Api { - status: Some(500), - code: Some("mock_error".to_string()), - message: "No more mock responses".to_string(), - }); + return Err(ClientError::Api { + status: Some(500), + code: Some("mock_error".to_string()), + message: "No more mock responses".to_string(), + }); } let events = self.responses[count].clone(); let stream = futures::stream::iter(events.into_iter().map(Ok)); @@ -135,7 +135,8 @@ pub fn assert_event_sequence(subdir: &str) { } // Find a text-based fixture - let fixture_path = fixtures.iter() + let fixture_path = fixtures + .iter() .find(|p| p.to_string_lossy().contains("text")) .unwrap_or(&fixtures[0]); @@ -156,9 +157,9 @@ pub fn assert_event_sequence(subdir: &str) { } } Event::BlockDelta(delta) => { - if let DeltaContent::Text(_) = &delta.delta { - delta_found = true; - } + if let DeltaContent::Text(_) = &delta.delta { + delta_found = true; + } } Event::BlockStop(stop) => { if stop.block_type == BlockType::Text { @@ -173,9 +174,9 @@ pub fn assert_event_sequence(subdir: &str) { // Check for BlockStart (Warn only for OpenAI/Ollama as it might be missing for text) if !start_found { - println!("Warning: No BlockStart found. This is common for OpenAI/Ollama text streams."); - // For Anthropic, strict start is usually expected, but to keep common logic simple we allow warning. - // If specific strictness is needed, we could add a `strict: bool` arg. + println!("Warning: No BlockStart found. This is common for OpenAI/Ollama text streams."); + // For Anthropic, strict start is usually expected, but to keep common logic simple we allow warning. + // If specific strictness is needed, we could add a `strict: bool` arg. } assert!(delta_found, "Should contain BlockDelta"); @@ -184,7 +185,9 @@ pub fn assert_event_sequence(subdir: &str) { assert!(stop_found, "Should contain BlockStop for Text block"); } else { if !stop_found { - println!(" [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)"); + println!( + " [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)" + ); } } } @@ -200,13 +203,23 @@ pub fn assert_usage_tokens(subdir: &str) { let events = load_events_from_fixture(&fixture); let usage_events: Vec<_> = events .iter() - .filter_map(|e| if let Event::Usage(u) = e { Some(u) } else { None }) + .filter_map(|e| { + if let Event::Usage(u) = e { + Some(u) + } else { + None + } + }) .collect(); if !usage_events.is_empty() { let last_usage = usage_events.last().unwrap(); if last_usage.input_tokens.is_some() || last_usage.output_tokens.is_some() { - println!(" Fixture {:?} Usage: {:?}", fixture.file_name(), last_usage); + println!( + " Fixture {:?} Usage: {:?}", + fixture.file_name(), + last_usage + ); return; // Found valid usage } } @@ -221,7 +234,8 @@ pub fn assert_timeline_integration(subdir: &str) { return; } - let fixture_path = fixtures.iter() + let fixture_path = fixtures + .iter() .find(|p| p.to_string_lossy().contains("text")) .unwrap_or(&fixtures[0]); diff --git a/worker/tests/parallel_execution_test.rs b/worker/tests/parallel_execution_test.rs index deb0715..49888f6 100644 --- a/worker/tests/parallel_execution_test.rs +++ b/worker/tests/parallel_execution_test.rs @@ -2,13 +2,16 @@ //! //! Workerが複数のツールを並列に実行することを確認する。 -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; use async_trait::async_trait; use worker::Worker; -use worker_types::{Event, Message, ResponseStatus, StatusEvent, Tool, ToolError, ToolResult, ToolCall, ControlFlow, HookError, WorkerHook}; +use worker_types::{ + ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError, + ToolResult, WorkerHook, +}; mod common; use common::MockLlmClient; @@ -105,8 +108,6 @@ async fn test_parallel_tool_execution() { worker.register_tool(tool2); worker.register_tool(tool3); - - let messages = vec![Message::user("Run all tools")]; let start = Instant::now(); @@ -161,7 +162,10 @@ async fn test_before_tool_call_skip() { #[async_trait] impl WorkerHook for BlockingHook { - async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result { + async fn before_tool_call( + &self, + tool_call: &mut ToolCall, + ) -> Result { if tool_call.name == "blocked_tool" { Ok(ControlFlow::Skip) } else { @@ -176,8 +180,16 @@ async fn test_before_tool_call_skip() { let _result = worker.run(messages).await; // allowed_tool は呼び出されるが、blocked_tool は呼び出されない - assert_eq!(allowed_clone.call_count(), 1, "Allowed tool should be called"); - assert_eq!(blocked_clone.call_count(), 0, "Blocked tool should not be called"); + assert_eq!( + allowed_clone.call_count(), + 1, + "Allowed tool should be called" + ); + assert_eq!( + blocked_clone.call_count(), + 0, + "Blocked tool should not be called" + ); } /// Hook: after_tool_call で結果が改変されることを確認 @@ -212,9 +224,15 @@ async fn test_after_tool_call_modification() { #[async_trait] impl Tool for SimpleTool { - fn name(&self) -> &str { "test_tool" } - fn description(&self) -> &str { "Test" } - fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) } + fn name(&self) -> &str { + "test_tool" + } + fn description(&self) -> &str { + "Test" + } + fn input_schema(&self) -> serde_json::Value { + serde_json::json!({}) + } async fn execute(&self, _: &str) -> Result { Ok("Original Result".to_string()) } @@ -229,7 +247,10 @@ async fn test_after_tool_call_modification() { #[async_trait] impl WorkerHook for ModifyingHook { - async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result { + async fn after_tool_call( + &self, + tool_result: &mut ToolResult, + ) -> Result { tool_result.content = format!("[Modified] {}", tool_result.content); *self.modified_content.lock().unwrap() = Some(tool_result.content.clone()); Ok(ControlFlow::Continue) @@ -237,7 +258,9 @@ async fn test_after_tool_call_modification() { } let modified_content = Arc::new(std::sync::Mutex::new(None)); - worker.add_hook(ModifyingHook { modified_content: modified_content.clone() }); + worker.add_hook(ModifyingHook { + modified_content: modified_content.clone(), + }); let messages = vec![Message::user("Test modification")]; let result = worker.run(messages).await; diff --git a/worker/tests/tool_macro_test.rs b/worker/tests/tool_macro_test.rs index 98e1d00..9cb6d4d 100644 --- a/worker/tests/tool_macro_test.rs +++ b/worker/tests/tool_macro_test.rs @@ -2,8 +2,8 @@ //! //! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。 -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; // マクロ展開に必要なインポート use schemars; @@ -59,12 +59,19 @@ async fn test_basic_tool_generation() { // 説明の確認(docコメントから取得) let desc = greet_tool.description(); - assert!(desc.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", desc); + assert!( + desc.contains("メッセージに挨拶を追加する"), + "Description should contain doc comment: {}", + desc + ); // スキーマの確認 let schema = greet_tool.input_schema(); println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap()); - assert!(schema.get("properties").is_some(), "Schema should have properties"); + assert!( + schema.get("properties").is_some(), + "Schema should have properties" + ); // 実行テスト let result = greet_tool.execute(r#"{"message": "World"}"#).await; @@ -104,7 +111,11 @@ async fn test_no_arguments() { let result = get_prefix_tool.execute(r#"{}"#).await; assert!(result.is_ok()); let output = result.unwrap(); - assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output); + assert!( + output.contains("TestPrefix"), + "Should contain prefix: {}", + output + ); } #[tokio::test] @@ -169,7 +180,11 @@ async fn test_result_return_type_error() { assert!(result.is_err(), "Should fail for negative value"); let err = result.unwrap_err(); - assert!(err.to_string().contains("positive"), "Error should mention positive: {}", err); + assert!( + err.to_string().contains("positive"), + "Error should mention positive: {}", + err + ); } // ============================================================================= diff --git a/worker/tests/worker_fixtures.rs b/worker/tests/worker_fixtures.rs index b8a0d47..b8d7a1a 100644 --- a/worker/tests/worker_fixtures.rs +++ b/worker/tests/worker_fixtures.rs @@ -6,8 +6,8 @@ mod common; use std::path::Path; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use async_trait::async_trait; use common::MockLlmClient; @@ -67,9 +67,7 @@ impl Tool for MockWeatherTool { let input: serde_json::Value = serde_json::from_str(input_json) .map_err(|e| ToolError::InvalidArgument(e.to_string()))?; - let city = input["city"] - .as_str() - .unwrap_or("Unknown"); + let city = input["city"].as_str().unwrap_or("Unknown"); // モックのレスポンスを返す Ok(format!("Weather in {}: Sunny, 22°C", city)) @@ -163,8 +161,6 @@ async fn test_worker_tool_call() { let tool_for_check = weather_tool.clone(); worker.register_tool(weather_tool); - - // メッセージを送信 let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")]; let _result = worker.run(messages).await; @@ -212,8 +208,8 @@ async fn test_worker_with_programmatic_events() { /// id, name, input(JSON)を正しく抽出できることを検証する。 #[tokio::test] async fn test_tool_call_collector_integration() { - use worker::ToolCallCollector; use worker::Timeline; + use worker::ToolCallCollector; use worker_types::Event; // ToolUseブロックを含むイベントシーケンス