fmt: cargo fmt

Keisuke Hirata 2026-01-07 22:04:44 +09:00
parent bb73dc6a45
commit 1e126c1698
20 changed files with 263 additions and 227 deletions

View File

@@ -6,7 +6,7 @@
 use proc_macro::TokenStream;
 use quote::{format_ident, quote};
 use syn::{
-    parse_macro_input, Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type,
+    Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
 };

 /// A macro applied to an `impl` block; it generates tools from the methods inside that carry the `#[tool]` attribute.
@@ -311,7 +311,7 @@ pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
 }

 /// Marker for argument attributes. Interpreted by `tool_registry` at parse time.
 ///
 /// # Example
 /// ```ignore
 /// #[tool]

View File

@@ -127,7 +127,10 @@ pub trait WorkerHook: Send + Sync {
     /// Called after a tool call.
     ///
     /// The result can be rewritten or hidden here.
-    async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
+    async fn after_tool_call(
+        &self,
+        _tool_result: &mut ToolResult,
+    ) -> Result<ControlFlow, HookError> {
         Ok(ControlFlow::Continue)
     }

View File

@@ -54,7 +54,10 @@ pub enum ContentPart {
     },
     /// Tool result
     #[serde(rename = "tool_result")]
-    ToolResult { tool_use_id: String, content: String },
+    ToolResult {
+        tool_use_id: String,
+        content: String,
+    },
 }

 impl Message {

View File

@@ -3,9 +3,7 @@
 //! Designed as a thin wrapper over the Timeline layer's Handler mechanism,
 //! enabling streaming display to the UI and real-time feedback.

-use crate::{
-    ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent,
-};
+use crate::{ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent};

 // =============================================================================
 // WorkerSubscriber Trait
@@ -74,7 +72,11 @@ pub trait WorkerSubscriber: Send {
     ///
     /// Follows a Start/InputJsonDelta/Stop lifecycle.
     #[allow(unused_variables)]
-    fn on_tool_use_block(&mut self, scope: &mut Self::ToolUseBlockScope, event: &ToolUseBlockEvent) {
+    fn on_tool_use_block(
+        &mut self,
+        scope: &mut Self::ToolUseBlockScope,
+        event: &ToolUseBlockEvent,
+    ) {
     }

     // =========================================================================

View File

@@ -111,8 +111,8 @@ impl Handler<UsageKind> for UsageTracker {
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Read the API key from an environment variable
-    let api_key = std::env::var("GEMINI_API_KEY")
-        .expect("GEMINI_API_KEY environment variable must be set");
+    let api_key =
+        std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set");

     println!("=== Gemini LLM Client + Timeline Integration Example ===\n");

View File

@@ -16,9 +16,6 @@
 //! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
 //! ```

 mod recorder;
 mod scenarios;
@@ -82,7 +79,8 @@ async fn run_scenario_with_openai(
     subdir: &str,
     model: Option<String>,
 ) -> Result<(), Box<dyn std::error::Error>> {
-    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set");
+    let api_key =
+        std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set");
     let model = model.as_deref().unwrap_or("gpt-4o");

     let client = OpenAIClient::new(&api_key, model);
@@ -125,8 +123,8 @@ async fn run_scenario_with_gemini(
     subdir: &str,
     model: Option<String>,
 ) -> Result<(), Box<dyn std::error::Error>> {
-    let api_key = std::env::var("GEMINI_API_KEY")
-        .expect("GEMINI_API_KEY environment variable must be set");
+    let api_key =
+        std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set");
     let model = model.as_deref().unwrap_or("gemini-2.0-flash");

     let client = GeminiClient::new(&api_key, model);
@@ -142,9 +140,6 @@ async fn run_scenario_with_gemini(
     Ok(())
 }

 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     dotenv::dotenv().ok();
@@ -173,13 +168,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .collect();
         if found.is_empty() {
             eprintln!("Error: Unknown scenario '{}'", scenario_name);
             // List the available scenarios so the user can find the correct name
             println!("Available scenarios:");
             for s in scenarios::scenarios() {
                 println!("  {}", s.output_name);
             }
             std::process::exit(1);
         }
         found
     };
@@ -201,12 +196,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Scenario filtering has already been done by the logic in main.rs;
     // here we just run them in a simple loop
     for scenario in scenarios_to_run {
         match args.client {
-            ClientType::Anthropic => run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await?,
-            ClientType::Gemini => run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await?,
-            ClientType::Openai => run_scenario_with_openai(&scenario, subdir, args.model.clone()).await?,
-            ClientType::Ollama => run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await?,
-        }
+            ClientType::Anthropic => {
+                run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await?
+            }
+            ClientType::Gemini => {
+                run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await?
+            }
+            ClientType::Openai => {
+                run_scenario_with_openai(&scenario, subdir, args.model.clone()).await?
+            }
+            ClientType::Ollama => {
+                run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await?
+            }
+        }
     }

     println!("\n✅ Done!");

View File

@@ -38,14 +38,14 @@ use tracing_subscriber::EnvFilter;
 use clap::{Parser, ValueEnum};
 use worker::{
+    Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
     llm_client::{
+        LlmClient,
         providers::{
             anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
             openai::OpenAIClient,
         },
-        LlmClient,
     },
-    Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
 };
 use worker_macros::tool_registry;
 use worker_types::Message;
@@ -310,9 +310,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Initialize logging
     // RUST_LOG=debug cargo run --example worker_cli ... shows detailed logs
     // Default level is warn; it can be overridden with the RUST_LOG environment variable
-    let filter = EnvFilter::try_from_default_env()
-        .unwrap_or_else(|_| EnvFilter::new("warn"));
+    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
     tracing_subscriber::fmt()
         .with_env_filter(filter)
         .with_target(true)
@@ -320,7 +319,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Parse CLI arguments
     let args = Args::parse();

     info!(
         provider = ?args.provider,
         model = ?args.model,

View File

@@ -6,7 +6,7 @@ use std::pin::Pin;
 use async_trait::async_trait;
 use eventsource_stream::Eventsource;
-use futures::{future::ready, Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt, TryStreamExt, future::ready};
 use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
 use worker_types::Event;
@@ -178,7 +178,6 @@ impl LlmClient for AnthropicClient {
     }
 }

 #[cfg(test)]
 mod tests {
     use super::*;

View File

@@ -10,9 +10,7 @@ use futures::Stream;
 use worker_types::Event;

 use crate::llm_client::{
-    ClientError, LlmClient, Request,
-    providers::openai::OpenAIClient,
-    scheme::openai::OpenAIScheme,
+    ClientError, LlmClient, Request, providers::openai::OpenAIClient, scheme::openai::OpenAIScheme,
 };

 /// Ollama client
@@ -29,7 +27,7 @@ impl OllamaClient {
         // Ollama usually runs on localhost:11434/v1
         // The API key is "ollama" or ignored
         let base_url = "http://localhost:11434";

         let scheme = OpenAIScheme::new().with_legacy_max_tokens(true);

         let client = OpenAIClient::new("ollama", model)
@@ -37,7 +35,7 @@ impl OllamaClient {
             .with_scheme(scheme);

         // Currently OpenAIScheme sets include_usage: true. Does Ollama support this?
        // Assuming modern Ollama versions support usage.

         Self { inner: client }
     }
@@ -46,7 +44,7 @@ impl OllamaClient {
         self.inner = self.inner.with_base_url(url);
         self
     }

     /// Set a custom HTTP client
     pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
         self.inner = self.inner.with_http_client(client);

View File

@@ -61,21 +61,21 @@ impl OpenAIClient {
         let mut headers = HeaderMap::new();
         headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

         let api_key_val = if self.api_key.is_empty() {
             // For providers like Ollama, the API key might be empty or a dummy value.
             // Typical OpenAI usage requires it, though.
             // We allow an empty key if the user intends it, but it is usually checked.
             HeaderValue::from_static("")
         } else {
             let mut val = HeaderValue::from_str(&format!("Bearer {}", self.api_key))
                 .map_err(|e| ClientError::Config(format!("Invalid API key: {}", e)))?;
             val.set_sensitive(true);
             val
         };

         if !api_key_val.is_empty() {
             headers.insert("Authorization", api_key_val);
         }

         Ok(headers)
@@ -92,24 +92,24 @@ impl LlmClient for OpenAIClient {
         // Standard OpenAI base is "https://api.openai.com"; the endpoint is "/v1/chat/completions".
         // If an externally supplied base_url already includes /v1, we should be careful.
         // Assume defaults: if the user provides "http://localhost:11434/v1", we append "/chat/completions".
         // Or, cleaner: should the user provide the full base up to the version segment?
         // The Anthropic client uses "{}/v1/messages".
         // Stick to appending "/v1/chat/completions" when the base is just a host,
         // OR assume the base includes /v1 when the user overrides it?
         // Use robust joining, or a simple assumption matching the Anthropic pattern:
         // Default: https://api.openai.com -> https://api.openai.com/v1/chat/completions
         // However, the Ollama default is http://localhost:11434/v1/chat/completions when using the OpenAI-compatible API.
         // Since base_url is configurable via `with_base_url`, this stays flexible.
         // Detect whether /v1 is already present, otherwise append consistently.
         // Ideally `base_url` should be the root passed to `new`.
         let url = if self.base_url.ends_with("/v1") {
             format!("{}/chat/completions", self.base_url)
         } else if self.base_url.ends_with("/") {
             format!("{}v1/chat/completions", self.base_url)
         } else {
             format!("{}/v1/chat/completions", self.base_url)
         };

         let headers = self.build_headers()?;
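As an aside on the hunk above: the comments reason about whether `base_url` already ends in `/v1` before appending the chat-completions path. A minimal illustrative sketch of that joining rule (a free function written for this note, not the crate's actual API) behaves as follows:

// Illustrative sketch only: mirrors the joining rule described in the comments above,
// i.e. append "/v1/chat/completions" unless the base already ends with "/v1".
fn chat_completions_url(base_url: &str) -> String {
    if base_url.ends_with("/v1") {
        format!("{}/chat/completions", base_url)
    } else if base_url.ends_with('/') {
        format!("{}v1/chat/completions", base_url)
    } else {
        format!("{}/v1/chat/completions", base_url)
    }
}

fn main() {
    // Default OpenAI host, and an Ollama-style base that already includes /v1
    assert_eq!(
        chat_completions_url("https://api.openai.com"),
        "https://api.openai.com/v1/chat/completions"
    );
    assert_eq!(
        chat_completions_url("http://localhost:11434/v1"),
        "http://localhost:11434/v1/chat/completions"
    );
}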
@@ -159,40 +159,41 @@ impl LlmClient for OpenAIClient {
             .map_err(|e| std::io::Error::other(e));
         let event_stream = byte_stream.eventsource();

-        let stream = event_stream.map(move |result| {
-            match result {
-                Ok(event) => {
-                    // Parse the SSE event
-                    // OpenAI stream events are "data: {...}"
-                    // event.event is usually "message" (default) or empty.
-                    // parse_event takes the data string.
-
-                    if event.data == "[DONE]" {
-                        // End of stream; parse_event usually handles this by returning None
-                        Ok(None)
-                    } else {
-                        match scheme.parse_event(&event.data) {
-                            Ok(Some(events)) => Ok(Some(events)),
-                            Ok(None) => Ok(None),
-                            Err(e) => Err(e),
-                        }
-                    }
-                }
-                Err(e) => Err(ClientError::Sse(e.to_string())),
-            }
-        })
-        // flatten the Option<Vec<Event>> stream to a Stream<Event>
-        // map returns Result<Option<Vec<Event>>, Error>
-        // We want Stream<Item = Result<Event, Error>>
-        .map(|res| {
-            let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
-                Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
-                Ok(None) => Box::pin(futures::stream::empty()),
-                Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
-            };
-            s
-        })
-        .flatten();
+        let stream = event_stream
+            .map(move |result| {
+                match result {
+                    Ok(event) => {
+                        // Parse the SSE event
+                        // OpenAI stream events are "data: {...}"
+                        // event.event is usually "message" (default) or empty.
+                        // parse_event takes the data string.
+
+                        if event.data == "[DONE]" {
+                            // End of stream; parse_event usually handles this by returning None
+                            Ok(None)
+                        } else {
+                            match scheme.parse_event(&event.data) {
+                                Ok(Some(events)) => Ok(Some(events)),
+                                Ok(None) => Ok(None),
+                                Err(e) => Err(e),
+                            }
+                        }
+                    }
+                    Err(e) => Err(ClientError::Sse(e.to_string())),
+                }
+            })
+            // flatten the Option<Vec<Event>> stream to a Stream<Event>
+            // map returns Result<Option<Vec<Event>>, Error>
+            // We want Stream<Item = Result<Event, Error>>
+            .map(|res| {
+                let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
+                    Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
+                    Ok(None) => Box::pin(futures::stream::empty()),
+                    Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
+                };
+                s
+            })
+            .flatten();

         Ok(Box::pin(stream))
     }
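The comments in the hunk above describe mapping each SSE chunk to `Result<Option<Vec<Event>>, ClientError>` and then flattening that into a `Stream<Item = Result<Event, ClientError>>`. A self-contained sketch of the same flattening pattern, with a plain `String` standing in for `Event` and assuming only the `futures` and `tokio` crates, looks roughly like this:

use futures::{Stream, StreamExt, stream};
use std::pin::Pin;

// Flatten a stream of Result<Option<Vec<T>>, E> into a stream of Result<T, E>:
// Ok(Some(items)) fans out to one Ok per item, Ok(None) yields nothing,
// and Err(e) becomes a single error item.
fn flatten_batches<T, E>(
    input: impl Stream<Item = Result<Option<Vec<T>>, E>> + Send + 'static,
) -> impl Stream<Item = Result<T, E>> + Send
where
    T: Send + 'static,
    E: Send + 'static,
{
    input
        .map(|res| {
            let s: Pin<Box<dyn Stream<Item = Result<T, E>> + Send>> = match res {
                Ok(Some(items)) => Box::pin(stream::iter(items.into_iter().map(Ok))),
                Ok(None) => Box::pin(stream::empty()),
                Err(e) => Box::pin(stream::once(async move { Err(e) })),
            };
            s
        })
        .flatten()
}

#[tokio::main]
async fn main() {
    let input = stream::iter(vec![
        Ok(Some(vec!["a".to_string(), "b".to_string()])),
        Ok(None),
        Err("boom".to_string()),
    ]);
    let out: Vec<Result<String, String>> = flatten_batches(input).collect().await;
    println!("{:?}", out); // [Ok("a"), Ok("b"), Err("boom")]
}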

View File

@@ -127,13 +127,12 @@ impl GeminiScheme {
             return Ok(None);
         }

-        let response: GenerateContentResponse = serde_json::from_str(data).map_err(|e| {
-            ClientError::Api {
-                status: None,
-                code: Some("parse_error".to_string()),
-                message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data),
-            }
-        })?;
+        let response: GenerateContentResponse =
+            serde_json::from_str(data).map_err(|e| ClientError::Api {
+                status: None,
+                code: Some("parse_error".to_string()),
+                message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data),
+            })?;

         let mut events = Vec::new();
@@ -155,10 +154,7 @@ impl GeminiScheme {
                     if !text.is_empty() {
                         // Gemini does not send an explicit BlockStart, so the Timeline
                         // implicitly handles the start when a TextDelta is sent directly
-                        events.push(Event::text_delta(
-                            part_index,
-                            text.clone(),
-                        ));
+                        events.push(Event::text_delta(part_index, text.clone()));
                     }
                 }
@@ -167,10 +163,10 @@ impl GeminiScheme {
                     // Start of a function call
                     // Gemini usually sends a function call all at once;
                    // when streaming arguments are enabled, it may arrive in parts
                     // Gemini has no function-call ID, so the name is used as the ID
                     let function_id = format!("call_{}", function_call.name);

                     events.push(Event::BlockStart(BlockStart {
                         index: candidate_index * 10 + part_index, // composite index
                         block_type: BlockType::ToolUse,
@@ -240,7 +236,8 @@ mod tests {
     #[test]
     fn test_parse_text_response() {
         let scheme = GeminiScheme::new();
-        let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#;
+        let data =
+            r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#;

         let events = scheme.parse_event(data).unwrap().unwrap();
         assert_eq!(events.len(), 1);
@@ -263,7 +260,7 @@ mod tests {
         let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"index":0}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}"#;
         let events = scheme.parse_event(data).unwrap().unwrap();

         // Should contain a Usage event
         let usage_event = events.iter().find(|e| matches!(e, Event::Usage(_)));
         assert!(usage_event.is_some());
@@ -281,7 +278,7 @@ mod tests {
         let data = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"location":"Tokyo"}}}],"role":"model"},"index":0}]}"#;
         let events = scheme.parse_event(data).unwrap().unwrap();

         // Should contain a BlockStart event
         let start_event = events.iter().find(|e| matches!(e, Event::BlockStart(_)));
         assert!(start_event.is_some());
@@ -312,7 +309,7 @@ mod tests {
         let data = r#"{"candidates":[{"content":{"parts":[{"text":"Done"}],"role":"model"},"finishReason":"STOP","index":0}]}"#;
         let events = scheme.parse_event(data).unwrap().unwrap();

         // Should contain a BlockStop event
         let stop_event = events.iter().find(|e| matches!(e, Event::BlockStop(_)));
         assert!(stop_event.is_some());
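A side note on the conventions in the hunk above: Gemini does not supply function-call IDs, so one is synthesized from the function name, and the block index is the composite `candidate_index * 10 + part_index`. A tiny illustrative sketch of those two conventions (free functions written for this note, not the crate's API):

// Illustrative only: the ID/index conventions described in the comments above.
fn synthetic_function_id(name: &str) -> String {
    format!("call_{}", name)
}

fn composite_block_index(candidate_index: usize, part_index: usize) -> usize {
    candidate_index * 10 + part_index
}

fn main() {
    assert_eq!(synthetic_function_id("get_weather"), "call_get_weather");
    assert_eq!(composite_block_index(0, 3), 3);
    assert_eq!(composite_block_index(2, 1), 21);
}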

View File

@@ -46,9 +46,7 @@ pub(crate) struct GeminiContent {
 #[serde(untagged)]
 pub(crate) enum GeminiPart {
     /// Text part
-    Text {
-        text: String,
-    },
+    Text { text: String },
     /// Function call part
     FunctionCall {
         #[serde(rename = "functionCall")]
@@ -160,11 +158,7 @@ impl GeminiScheme {
             vec![]
         } else {
             vec![GeminiTool {
-                function_declarations: request
-                    .tools
-                    .iter()
-                    .map(|t| self.convert_tool(t))
-                    .collect(),
+                function_declarations: request.tools.iter().map(|t| self.convert_tool(t)).collect(),
             }]
         };
@@ -224,34 +218,30 @@ impl GeminiScheme {
                     },
                 }]
             }
-            MessageContent::Parts(parts) => {
-                parts
-                    .iter()
-                    .map(|p| match p {
-                        ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
-                        ContentPart::ToolUse { id: _, name, input } => {
-                            GeminiPart::FunctionCall {
-                                function_call: GeminiFunctionCall {
-                                    name: name.clone(),
-                                    args: input.clone(),
-                                },
-                            }
-                        }
-                        ContentPart::ToolResult {
-                            tool_use_id,
-                            content,
-                        } => GeminiPart::FunctionResponse {
-                            function_response: GeminiFunctionResponse {
-                                name: tool_use_id.clone(),
-                                response: GeminiFunctionResponseContent {
-                                    name: tool_use_id.clone(),
-                                    content: serde_json::Value::String(content.clone()),
-                                },
-                            },
-                        },
-                    })
-                    .collect()
-            }
+            MessageContent::Parts(parts) => parts
+                .iter()
+                .map(|p| match p {
+                    ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
+                    ContentPart::ToolUse { id: _, name, input } => GeminiPart::FunctionCall {
+                        function_call: GeminiFunctionCall {
+                            name: name.clone(),
+                            args: input.clone(),
+                        },
+                    },
+                    ContentPart::ToolResult {
+                        tool_use_id,
+                        content,
+                    } => GeminiPart::FunctionResponse {
+                        function_response: GeminiFunctionResponse {
+                            name: tool_use_id.clone(),
+                            response: GeminiFunctionResponseContent {
+                                name: tool_use_id.clone(),
+                                content: serde_json::Value::String(content.clone()),
+                            },
+                        },
+                    },
+                })
+                .collect(),
         };

         GeminiContent {
@@ -306,16 +296,17 @@ mod tests {
         assert_eq!(gemini_req.tools.len(), 1);
         assert_eq!(gemini_req.tools[0].function_declarations.len(), 1);
-        assert_eq!(gemini_req.tools[0].function_declarations[0].name, "get_weather");
+        assert_eq!(
+            gemini_req.tools[0].function_declarations[0].name,
+            "get_weather"
+        );
         assert!(gemini_req.tool_config.is_some());
     }

     #[test]
     fn test_assistant_role_is_model() {
         let scheme = GeminiScheme::new();
-        let request = Request::new()
-            .user("Hello")
-            .assistant("Hi there!");
+        let request = Request::new().user("Hello").assistant("Hi there!");

         let gemini_req = scheme.build_request(&request);

View File

@@ -69,8 +69,8 @@ impl OpenAIScheme {
             return Ok(None);
         }

-        let chunk: ChatCompletionChunk = serde_json::from_str(data)
-            .map_err(|e| ClientError::Api {
+        let chunk: ChatCompletionChunk =
+            serde_json::from_str(data).map_err(|e| ClientError::Api {
                 status: None,
                 code: Some("parse_error".to_string()),
                 message: format!("Failed to parse SSE data: {} -> {}", e, data),
@@ -102,10 +102,14 @@ impl OpenAIScheme {
             for tool_call in tool_calls {
                 // Start of a tool call (has an ID)
                 if let Some(id) = tool_call.id {
-                    let name = tool_call.function.as_ref().and_then(|f| f.name.clone()).unwrap_or_default();
+                    let name = tool_call
+                        .function
+                        .as_ref()
+                        .and_then(|f| f.name.clone())
+                        .unwrap_or_default();
                     events.push(Event::tool_use_start(tool_call.index, id, name));
                 }

                 // Arguments delta
                 if let Some(function) = tool_call.function {
                     if let Some(args) = function.arguments {
@@ -116,7 +120,7 @@ impl OpenAIScheme {
                 }
             }
         }

         // Finish reason
         if let Some(finish_reason) = choice.finish_reason {
             let stop_reason = match finish_reason.as_str() {
@@ -125,9 +129,10 @@ impl OpenAIScheme {
                 "tool_calls" | "function_call" => Some(StopReason::ToolUse),
                 _ => Some(StopReason::EndTurn),
             };

-            let is_tool_finish = finish_reason == "tool_calls" || finish_reason == "function_call";
+            let is_tool_finish =
+                finish_reason == "tool_calls" || finish_reason == "function_call";

             if is_tool_finish {
                 // End of the tool call
                 // Note: OpenAI does not indicate which tool finished, so
@@ -156,11 +161,11 @@ mod tests {
     fn test_parse_text_delta() {
         let scheme = OpenAIScheme::new();
         let data = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;

         let events = scheme.parse_event(data).unwrap().unwrap();

         // OpenAI does not emit BlockStart, so only the delta is expected
         assert_eq!(events.len(), 1);

         if let Event::BlockDelta(delta) = &events[0] {
             assert_eq!(delta.index, 0);
             if let DeltaContent::Text(text) = &delta.delta {
@@ -178,9 +183,9 @@ mod tests {
         let scheme = OpenAIScheme::new();

         // Start of tool call
         let data_start = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}"#;

         let events = scheme.parse_event(data_start).unwrap().unwrap();
         assert_eq!(events.len(), 1);

         if let Event::BlockStart(start) = &events[0] {
             assert_eq!(start.index, 0);
             if let worker_types::BlockMetadata::ToolUse { id, name } = &start.metadata {

View File

@@ -120,12 +120,7 @@ impl OpenAIScheme {
             });
         }

-        messages.extend(
-            request
-                .messages
-                .iter()
-                .map(|m| self.convert_message(m))
-        );
+        messages.extend(request.messages.iter().map(|m| self.convert_message(m)));

         let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
@@ -143,7 +138,9 @@ impl OpenAIScheme {
             top_p: request.config.top_p,
             stop: request.config.stop_sequences.clone(),
             stream: true,
-            stream_options: Some(StreamOptions { include_usage: true }),
+            stream_options: Some(StreamOptions {
+                include_usage: true,
+            }),
             messages,
             tools,
             tool_choice: None, // Default to auto if tools are present? Or let the API decide (which is auto)
@@ -224,14 +221,14 @@ impl OpenAIScheme {
                 name: None,
             }
         } else {
             let content = if content_parts.is_empty() {
                 None
             } else if content_parts.len() == 1 {
                 // Simplify a single text part to plain Text content if preferred, or keep it as Parts
                 if let OpenAIContentPart::Text { text } = &content_parts[0] {
                     Some(OpenAIContent::Text(text.clone()))
                 } else {
                     Some(OpenAIContent::Parts(content_parts))
                 }
             } else {
                 Some(OpenAIContent::Parts(content_parts))
@@ -265,13 +262,10 @@ impl OpenAIScheme {
 mod tests {
     use super::*;

     #[test]
     fn test_build_simple_request() {
         let scheme = OpenAIScheme::new();
-        let request = Request::new()
-            .system("System prompt")
-            .user("Hello");
+        let request = Request::new().system("System prompt").user("Hello");

         let body = scheme.build_request("gpt-4o", &request);
@@ -279,7 +273,7 @@ mod tests {
         assert_eq!(body.messages.len(), 2);
         assert_eq!(body.messages[0].role, "system");
         assert_eq!(body.messages[1].role, "user");

         // Check the system content
         if let Some(OpenAIContent::Text(text)) = &body.messages[0].content {
             assert_eq!(text, "System prompt");
@@ -303,12 +297,10 @@ mod tests {
     #[test]
     fn test_build_request_legacy_max_tokens() {
         let scheme = OpenAIScheme::new().with_legacy_max_tokens(true);
-        let request = Request::new()
-            .user("Hello")
-            .max_tokens(100);
+        let request = Request::new().user("Hello").max_tokens(100);

         let body = scheme.build_request("llama3", &request);

         // max_tokens should be set, max_completion_tokens should be None
         assert_eq!(body.max_tokens, Some(100));
         assert!(body.max_completion_tokens.is_none());
@@ -317,12 +309,10 @@ mod tests {
     #[test]
     fn test_build_request_modern_max_tokens() {
         let scheme = OpenAIScheme::new(); // Default matches modern (legacy=false)
-        let request = Request::new()
-            .user("Hello")
-            .max_tokens(100);
+        let request = Request::new().user("Hello").max_tokens(100);

         let body = scheme.build_request("gpt-4o", &request);

         // max_completion_tokens should be set, max_tokens should be None
         assert_eq!(body.max_completion_tokens, Some(100));
         assert!(body.max_tokens.is_none());

View File

@@ -502,13 +502,13 @@ impl Timeline {
     fn handle_block_delta(&mut self, delta: &BlockDelta) {
         let block_type = delta.delta.block_type();

         // Providers such as OpenAI may not send BlockStart, so if no scope
         // exists when a Delta arrives, start one implicitly
         if self.current_block.is_none() {
             self.current_block = Some(block_type);
         }

         let handlers = self.get_block_handlers_mut(block_type);
         for handler in handlers {
             // Start implicitly if there is no scope yet

View File

@@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
 use futures::StreamExt;
 use tracing::{debug, info, trace, warn};

+use crate::Timeline;
 use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
 use crate::subscriber_adapter::{
     ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
@@ -11,7 +12,6 @@ use crate::subscriber_adapter::{
 };
 use crate::text_block_collector::TextBlockCollector;
 use crate::tool_call_collector::ToolCallCollector;
-use crate::Timeline;
 use worker_types::{
     ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError,
     ToolResult, TurnResult, WorkerHook, WorkerSubscriber,
@@ -223,7 +223,7 @@ impl<C: LlmClient> Worker<C> {
     pub async fn run(&mut self, messages: Vec<Message>) -> Result<Vec<Message>, WorkerError> {
         let mut context = messages;
         let tool_definitions = self.build_tool_definitions();

         info!(
             message_count = context.len(),
             tool_count = tool_definitions.len(),
@@ -442,10 +442,7 @@ impl<C: LlmClient> Worker<C> {
     }

     /// Hooks: on_turn_end
-    async fn run_on_turn_end_hooks(
-        &self,
-        messages: &[Message],
-    ) -> Result<TurnResult, WorkerError> {
+    async fn run_on_turn_end_hooks(&self, messages: &[Message]) -> Result<TurnResult, WorkerError> {
         for hook in &self.hooks {
             let result = hook.on_turn_end(messages).await?;
             match result {

View File

@@ -3,13 +3,13 @@
 use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::{Path, PathBuf};
-use std::sync::{Arc, Mutex};
 use std::pin::Pin;
+use std::sync::{Arc, Mutex};

 use async_trait::async_trait;
 use futures::Stream;
-use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline};
 use worker::llm_client::{ClientError, LlmClient, Request};
+use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline};
 use worker_types::{BlockType, DeltaContent, Event};

 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -51,11 +51,11 @@ impl LlmClient for MockLlmClient {
     ) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
         let count = self.call_count.fetch_add(1, Ordering::SeqCst);

         if count >= self.responses.len() {
             return Err(ClientError::Api {
                 status: Some(500),
                 code: Some("mock_error".to_string()),
                 message: "No more mock responses".to_string(),
             });
         }

         let events = self.responses[count].clone();
         let stream = futures::stream::iter(events.into_iter().map(Ok));
@@ -135,7 +135,8 @@ pub fn assert_event_sequence(subdir: &str) {
     }

     // Find a text-based fixture
-    let fixture_path = fixtures.iter()
+    let fixture_path = fixtures
+        .iter()
         .find(|p| p.to_string_lossy().contains("text"))
         .unwrap_or(&fixtures[0]);
@@ -156,9 +157,9 @@ pub fn assert_event_sequence(subdir: &str) {
                 }
             }
             Event::BlockDelta(delta) => {
                 if let DeltaContent::Text(_) = &delta.delta {
                     delta_found = true;
                 }
             }
             Event::BlockStop(stop) => {
                 if stop.block_type == BlockType::Text {
@@ -173,9 +174,9 @@ pub fn assert_event_sequence(subdir: &str) {
     // Check for BlockStart (warn only, since OpenAI/Ollama may omit it for text streams)
     if !start_found {
         println!("Warning: No BlockStart found. This is common for OpenAI/Ollama text streams.");
         // For Anthropic a strict start is usually expected, but to keep the common logic simple we only warn.
         // If specific strictness is needed, we could add a `strict: bool` arg.
     }

     assert!(delta_found, "Should contain BlockDelta");
@@ -184,7 +185,9 @@ pub fn assert_event_sequence(subdir: &str) {
         assert!(stop_found, "Should contain BlockStop for Text block");
     } else {
         if !stop_found {
-            println!("  [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)");
+            println!(
+                "  [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)"
+            );
         }
     }
 }
@@ -200,13 +203,23 @@ pub fn assert_usage_tokens(subdir: &str) {
         let events = load_events_from_fixture(&fixture);
         let usage_events: Vec<_> = events
             .iter()
-            .filter_map(|e| if let Event::Usage(u) = e { Some(u) } else { None })
+            .filter_map(|e| {
+                if let Event::Usage(u) = e {
+                    Some(u)
+                } else {
+                    None
+                }
+            })
             .collect();

         if !usage_events.is_empty() {
             let last_usage = usage_events.last().unwrap();
             if last_usage.input_tokens.is_some() || last_usage.output_tokens.is_some() {
-                println!("  Fixture {:?} Usage: {:?}", fixture.file_name(), last_usage);
+                println!(
+                    "  Fixture {:?} Usage: {:?}",
+                    fixture.file_name(),
+                    last_usage
+                );
                 return; // Found valid usage
             }
         }
@@ -221,7 +234,8 @@ pub fn assert_timeline_integration(subdir: &str) {
         return;
     }

-    let fixture_path = fixtures.iter()
+    let fixture_path = fixtures
+        .iter()
         .find(|p| p.to_string_lossy().contains("text"))
         .unwrap_or(&fixtures[0]);

View File

@@ -2,13 +2,16 @@
 //!
 //! Verifies that the Worker executes multiple tools in parallel.

-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::{Duration, Instant};

 use async_trait::async_trait;
 use worker::Worker;
-use worker_types::{Event, Message, ResponseStatus, StatusEvent, Tool, ToolError, ToolResult, ToolCall, ControlFlow, HookError, WorkerHook};
+use worker_types::{
+    ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError,
+    ToolResult, WorkerHook,
+};

 mod common;
 use common::MockLlmClient;
@@ -105,8 +108,6 @@ async fn test_parallel_tool_execution() {
     worker.register_tool(tool2);
     worker.register_tool(tool3);

     let messages = vec![Message::user("Run all tools")];

     let start = Instant::now();
@@ -161,7 +162,10 @@ async fn test_before_tool_call_skip() {
     #[async_trait]
     impl WorkerHook for BlockingHook {
-        async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
+        async fn before_tool_call(
+            &self,
+            tool_call: &mut ToolCall,
+        ) -> Result<ControlFlow, HookError> {
             if tool_call.name == "blocked_tool" {
                 Ok(ControlFlow::Skip)
             } else {
@@ -176,8 +180,16 @@ async fn test_before_tool_call_skip() {
     let _result = worker.run(messages).await;

     // allowed_tool is called, but blocked_tool is not
-    assert_eq!(allowed_clone.call_count(), 1, "Allowed tool should be called");
-    assert_eq!(blocked_clone.call_count(), 0, "Blocked tool should not be called");
+    assert_eq!(
+        allowed_clone.call_count(),
+        1,
+        "Allowed tool should be called"
+    );
+    assert_eq!(
+        blocked_clone.call_count(),
+        0,
+        "Blocked tool should not be called"
+    );
 }

 /// Hook: verify that the result is modified by after_tool_call
@@ -212,9 +224,15 @@ async fn test_after_tool_call_modification() {
     #[async_trait]
     impl Tool for SimpleTool {
-        fn name(&self) -> &str { "test_tool" }
-        fn description(&self) -> &str { "Test" }
-        fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) }
+        fn name(&self) -> &str {
+            "test_tool"
+        }
+        fn description(&self) -> &str {
+            "Test"
+        }
+        fn input_schema(&self) -> serde_json::Value {
+            serde_json::json!({})
+        }
         async fn execute(&self, _: &str) -> Result<String, ToolError> {
             Ok("Original Result".to_string())
         }
@@ -229,7 +247,10 @@ async fn test_after_tool_call_modification() {
     #[async_trait]
     impl WorkerHook for ModifyingHook {
-        async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
+        async fn after_tool_call(
+            &self,
+            tool_result: &mut ToolResult,
+        ) -> Result<ControlFlow, HookError> {
             tool_result.content = format!("[Modified] {}", tool_result.content);
             *self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
             Ok(ControlFlow::Continue)
@@ -237,7 +258,9 @@ async fn test_after_tool_call_modification() {
     }

     let modified_content = Arc::new(std::sync::Mutex::new(None));
-    worker.add_hook(ModifyingHook { modified_content: modified_content.clone() });
+    worker.add_hook(ModifyingHook {
+        modified_content: modified_content.clone(),
+    });

     let messages = vec![Message::user("Test modification")];
     let result = worker.run(messages).await;

View File

@@ -2,8 +2,8 @@
 //!
 //! Verifies the behavior of the `#[tool_registry]` and `#[tool]` macros.

-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};

 // Imports required for macro expansion
 use schemars;
@@ -59,12 +59,19 @@ async fn test_basic_tool_generation() {
     // Check the description (taken from the doc comment)
     let desc = greet_tool.description();
-    assert!(desc.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", desc);
+    assert!(
+        desc.contains("メッセージに挨拶を追加する"),
+        "Description should contain doc comment: {}",
+        desc
+    );

     // Check the schema
     let schema = greet_tool.input_schema();
     println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
-    assert!(schema.get("properties").is_some(), "Schema should have properties");
+    assert!(
+        schema.get("properties").is_some(),
+        "Schema should have properties"
+    );

     // Execution test
     let result = greet_tool.execute(r#"{"message": "World"}"#).await;
@@ -104,7 +111,11 @@ async fn test_no_arguments() {
     let result = get_prefix_tool.execute(r#"{}"#).await;
     assert!(result.is_ok());
     let output = result.unwrap();
-    assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output);
+    assert!(
+        output.contains("TestPrefix"),
+        "Should contain prefix: {}",
+        output
+    );
 }

 #[tokio::test]
@@ -169,7 +180,11 @@ async fn test_result_return_type_error() {
     assert!(result.is_err(), "Should fail for negative value");
     let err = result.unwrap_err();
-    assert!(err.to_string().contains("positive"), "Error should mention positive: {}", err);
+    assert!(
+        err.to_string().contains("positive"),
+        "Error should mention positive: {}",
+        err
+    );
 }

 // =============================================================================

View File

@@ -6,8 +6,8 @@
 mod common;

 use std::path::Path;
-use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};

 use async_trait::async_trait;
 use common::MockLlmClient;
@@ -67,9 +67,7 @@ impl Tool for MockWeatherTool {
         let input: serde_json::Value = serde_json::from_str(input_json)
             .map_err(|e| ToolError::InvalidArgument(e.to_string()))?;

-        let city = input["city"]
-            .as_str()
-            .unwrap_or("Unknown");
+        let city = input["city"].as_str().unwrap_or("Unknown");

         // Return a mock response
         Ok(format!("Weather in {}: Sunny, 22°C", city))
@@ -163,8 +161,6 @@ async fn test_worker_tool_call() {
     let tool_for_check = weather_tool.clone();
     worker.register_tool(weather_tool);

     // Send the messages
     let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];
     let _result = worker.run(messages).await;
@@ -212,8 +208,8 @@ async fn test_worker_with_programmatic_events() {
 /// Verifies that id, name, and the input JSON are extracted correctly.
 #[tokio::test]
 async fn test_tool_call_collector_integration() {
-    use worker::ToolCallCollector;
     use worker::Timeline;
+    use worker::ToolCallCollector;
     use worker_types::Event;

     // An event sequence containing a ToolUse block