fmt: cargo fmt
This commit is contained in:
parent
bb73dc6a45
commit
1e126c1698
|
|
@ -6,7 +6,7 @@
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use quote::{format_ident, quote};
|
use quote::{format_ident, quote};
|
||||||
use syn::{
|
use syn::{
|
||||||
parse_macro_input, Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type,
|
Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
|
/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,10 @@ pub trait WorkerHook: Send + Sync {
|
||||||
/// ツール実行後
|
/// ツール実行後
|
||||||
///
|
///
|
||||||
/// 結果を書き換えたり、隠蔽したりできる。
|
/// 結果を書き換えたり、隠蔽したりできる。
|
||||||
async fn after_tool_call(&self, _tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
|
async fn after_tool_call(
|
||||||
|
&self,
|
||||||
|
_tool_result: &mut ToolResult,
|
||||||
|
) -> Result<ControlFlow, HookError> {
|
||||||
Ok(ControlFlow::Continue)
|
Ok(ControlFlow::Continue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,10 @@ pub enum ContentPart {
|
||||||
},
|
},
|
||||||
/// ツール結果
|
/// ツール結果
|
||||||
#[serde(rename = "tool_result")]
|
#[serde(rename = "tool_result")]
|
||||||
ToolResult { tool_use_id: String, content: String },
|
ToolResult {
|
||||||
|
tool_use_id: String,
|
||||||
|
content: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,7 @@
|
||||||
//! Timeline層のHandler機構の薄いラッパーとして設計され、
|
//! Timeline層のHandler機構の薄いラッパーとして設計され、
|
||||||
//! UIへのストリーミング表示やリアルタイムフィードバックを可能にする。
|
//! UIへのストリーミング表示やリアルタイムフィードバックを可能にする。
|
||||||
|
|
||||||
use crate::{
|
use crate::{ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent};
|
||||||
ErrorEvent, StatusEvent, TextBlockEvent, ToolCall, ToolUseBlockEvent, UsageEvent,
|
|
||||||
};
|
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// WorkerSubscriber Trait
|
// WorkerSubscriber Trait
|
||||||
|
|
@ -74,7 +72,11 @@ pub trait WorkerSubscriber: Send {
|
||||||
///
|
///
|
||||||
/// Start/InputJsonDelta/Stopのライフサイクルを持つ。
|
/// Start/InputJsonDelta/Stopのライフサイクルを持つ。
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
fn on_tool_use_block(&mut self, scope: &mut Self::ToolUseBlockScope, event: &ToolUseBlockEvent) {
|
fn on_tool_use_block(
|
||||||
|
&mut self,
|
||||||
|
scope: &mut Self::ToolUseBlockScope,
|
||||||
|
event: &ToolUseBlockEvent,
|
||||||
|
) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// =========================================================================
|
// =========================================================================
|
||||||
|
|
|
||||||
|
|
@ -111,8 +111,8 @@ impl Handler<UsageKind> for UsageTracker {
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// APIキーを環境変数から取得
|
// APIキーを環境変数から取得
|
||||||
let api_key = std::env::var("GEMINI_API_KEY")
|
let api_key =
|
||||||
.expect("GEMINI_API_KEY environment variable must be set");
|
std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set");
|
||||||
|
|
||||||
println!("=== Gemini LLM Client + Timeline Integration Example ===\n");
|
println!("=== Gemini LLM Client + Timeline Integration Example ===\n");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,6 @@
|
||||||
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
|
//! ANTHROPIC_API_KEY=your-key cargo run --example record_test_fixtures -- --all
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
mod recorder;
|
mod recorder;
|
||||||
mod scenarios;
|
mod scenarios;
|
||||||
|
|
||||||
|
|
@ -82,7 +79,8 @@ async fn run_scenario_with_openai(
|
||||||
subdir: &str,
|
subdir: &str,
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set");
|
let api_key =
|
||||||
|
std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY environment variable must be set");
|
||||||
let model = model.as_deref().unwrap_or("gpt-4o");
|
let model = model.as_deref().unwrap_or("gpt-4o");
|
||||||
let client = OpenAIClient::new(&api_key, model);
|
let client = OpenAIClient::new(&api_key, model);
|
||||||
|
|
||||||
|
|
@ -125,8 +123,8 @@ async fn run_scenario_with_gemini(
|
||||||
subdir: &str,
|
subdir: &str,
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let api_key = std::env::var("GEMINI_API_KEY")
|
let api_key =
|
||||||
.expect("GEMINI_API_KEY environment variable must be set");
|
std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY environment variable must be set");
|
||||||
let model = model.as_deref().unwrap_or("gemini-2.0-flash");
|
let model = model.as_deref().unwrap_or("gemini-2.0-flash");
|
||||||
let client = GeminiClient::new(&api_key, model);
|
let client = GeminiClient::new(&api_key, model);
|
||||||
|
|
||||||
|
|
@ -142,9 +140,6 @@ async fn run_scenario_with_gemini(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
dotenv::dotenv().ok();
|
dotenv::dotenv().ok();
|
||||||
|
|
@ -173,13 +168,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if found.is_empty() {
|
if found.is_empty() {
|
||||||
eprintln!("Error: Unknown scenario '{}'", scenario_name);
|
eprintln!("Error: Unknown scenario '{}'", scenario_name);
|
||||||
// Verify correct name by listing
|
// Verify correct name by listing
|
||||||
println!("Available scenarios:");
|
println!("Available scenarios:");
|
||||||
for s in scenarios::scenarios() {
|
for s in scenarios::scenarios() {
|
||||||
println!(" {}", s.output_name);
|
println!(" {}", s.output_name);
|
||||||
}
|
}
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
found
|
found
|
||||||
};
|
};
|
||||||
|
|
@ -201,12 +196,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// シナリオのフィルタリングは main.rs のロジックで実行済み
|
// シナリオのフィルタリングは main.rs のロジックで実行済み
|
||||||
// ここでは単純なループで実行
|
// ここでは単純なループで実行
|
||||||
for scenario in scenarios_to_run {
|
for scenario in scenarios_to_run {
|
||||||
match args.client {
|
match args.client {
|
||||||
ClientType::Anthropic => run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await?,
|
ClientType::Anthropic => {
|
||||||
ClientType::Gemini => run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await?,
|
run_scenario_with_anthropic(&scenario, subdir, args.model.clone()).await?
|
||||||
ClientType::Openai => run_scenario_with_openai(&scenario, subdir, args.model.clone()).await?,
|
}
|
||||||
ClientType::Ollama => run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await?,
|
ClientType::Gemini => {
|
||||||
}
|
run_scenario_with_gemini(&scenario, subdir, args.model.clone()).await?
|
||||||
|
}
|
||||||
|
ClientType::Openai => {
|
||||||
|
run_scenario_with_openai(&scenario, subdir, args.model.clone()).await?
|
||||||
|
}
|
||||||
|
ClientType::Ollama => {
|
||||||
|
run_scenario_with_ollama(&scenario, subdir, args.model.clone()).await?
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("\n✅ Done!");
|
println!("\n✅ Done!");
|
||||||
|
|
|
||||||
|
|
@ -38,14 +38,14 @@ use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use worker::{
|
use worker::{
|
||||||
|
Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
|
||||||
llm_client::{
|
llm_client::{
|
||||||
|
LlmClient,
|
||||||
providers::{
|
providers::{
|
||||||
anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
|
anthropic::AnthropicClient, gemini::GeminiClient, ollama::OllamaClient,
|
||||||
openai::OpenAIClient,
|
openai::OpenAIClient,
|
||||||
},
|
},
|
||||||
LlmClient,
|
|
||||||
},
|
},
|
||||||
Handler, TextBlockEvent, TextBlockKind, ToolUseBlockEvent, ToolUseBlockKind, Worker,
|
|
||||||
};
|
};
|
||||||
use worker_macros::tool_registry;
|
use worker_macros::tool_registry;
|
||||||
use worker_types::Message;
|
use worker_types::Message;
|
||||||
|
|
@ -310,8 +310,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// ロギング初期化
|
// ロギング初期化
|
||||||
// RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示
|
// RUST_LOG=debug cargo run --example worker_cli ... で詳細ログ表示
|
||||||
// デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能
|
// デフォルトは warn レベル、RUST_LOG 環境変数で上書き可能
|
||||||
let filter = EnvFilter::try_from_default_env()
|
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("warn"));
|
||||||
.unwrap_or_else(|_| EnvFilter::new("warn"));
|
|
||||||
|
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_env_filter(filter)
|
.with_env_filter(filter)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ use std::pin::Pin;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eventsource_stream::Eventsource;
|
use eventsource_stream::Eventsource;
|
||||||
use futures::{future::ready, Stream, StreamExt, TryStreamExt};
|
use futures::{Stream, StreamExt, TryStreamExt, future::ready};
|
||||||
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
|
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
|
||||||
use worker_types::Event;
|
use worker_types::Event;
|
||||||
|
|
||||||
|
|
@ -178,7 +178,6 @@ impl LlmClient for AnthropicClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,7 @@ use futures::Stream;
|
||||||
use worker_types::Event;
|
use worker_types::Event;
|
||||||
|
|
||||||
use crate::llm_client::{
|
use crate::llm_client::{
|
||||||
ClientError, LlmClient, Request,
|
ClientError, LlmClient, Request, providers::openai::OpenAIClient, scheme::openai::OpenAIScheme,
|
||||||
providers::openai::OpenAIClient,
|
|
||||||
scheme::openai::OpenAIScheme,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Ollama クライアント
|
/// Ollama クライアント
|
||||||
|
|
|
||||||
|
|
@ -66,16 +66,16 @@ impl OpenAIClient {
|
||||||
// For providers like Ollama, API key might be empty/dummy.
|
// For providers like Ollama, API key might be empty/dummy.
|
||||||
// But typical OpenAI requires it.
|
// But typical OpenAI requires it.
|
||||||
// We'll allow empty if user intends it, but usually it's checked.
|
// We'll allow empty if user intends it, but usually it's checked.
|
||||||
HeaderValue::from_static("")
|
HeaderValue::from_static("")
|
||||||
} else {
|
} else {
|
||||||
let mut val = HeaderValue::from_str(&format!("Bearer {}", self.api_key))
|
let mut val = HeaderValue::from_str(&format!("Bearer {}", self.api_key))
|
||||||
.map_err(|e| ClientError::Config(format!("Invalid API key: {}", e)))?;
|
.map_err(|e| ClientError::Config(format!("Invalid API key: {}", e)))?;
|
||||||
val.set_sensitive(true);
|
val.set_sensitive(true);
|
||||||
val
|
val
|
||||||
};
|
};
|
||||||
|
|
||||||
if !api_key_val.is_empty() {
|
if !api_key_val.is_empty() {
|
||||||
headers.insert("Authorization", api_key_val);
|
headers.insert("Authorization", api_key_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(headers)
|
Ok(headers)
|
||||||
|
|
@ -105,11 +105,11 @@ impl LlmClient for OpenAIClient {
|
||||||
// Ideally `base_url` should be the root passed to `new`.
|
// Ideally `base_url` should be the root passed to `new`.
|
||||||
|
|
||||||
let url = if self.base_url.ends_with("/v1") {
|
let url = if self.base_url.ends_with("/v1") {
|
||||||
format!("{}/chat/completions", self.base_url)
|
format!("{}/chat/completions", self.base_url)
|
||||||
} else if self.base_url.ends_with("/") {
|
} else if self.base_url.ends_with("/") {
|
||||||
format!("{}v1/chat/completions", self.base_url)
|
format!("{}v1/chat/completions", self.base_url)
|
||||||
} else {
|
} else {
|
||||||
format!("{}/v1/chat/completions", self.base_url)
|
format!("{}/v1/chat/completions", self.base_url)
|
||||||
};
|
};
|
||||||
|
|
||||||
let headers = self.build_headers()?;
|
let headers = self.build_headers()?;
|
||||||
|
|
@ -159,40 +159,41 @@ impl LlmClient for OpenAIClient {
|
||||||
.map_err(|e| std::io::Error::other(e));
|
.map_err(|e| std::io::Error::other(e));
|
||||||
let event_stream = byte_stream.eventsource();
|
let event_stream = byte_stream.eventsource();
|
||||||
|
|
||||||
let stream = event_stream.map(move |result| {
|
let stream = event_stream
|
||||||
match result {
|
.map(move |result| {
|
||||||
Ok(event) => {
|
match result {
|
||||||
// SSEイベントをパース
|
Ok(event) => {
|
||||||
// OpenAI stream events are "data: {...}"
|
// SSEイベントをパース
|
||||||
// event.event is usually "message" (default) or empty.
|
// OpenAI stream events are "data: {...}"
|
||||||
// parse_event takes data string.
|
// event.event is usually "message" (default) or empty.
|
||||||
|
// parse_event takes data string.
|
||||||
|
|
||||||
if event.data == "[DONE]" {
|
if event.data == "[DONE]" {
|
||||||
// End of stream handled inside parse_event usually returning None
|
// End of stream handled inside parse_event usually returning None
|
||||||
Ok(None)
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
match scheme.parse_event(&event.data) {
|
match scheme.parse_event(&event.data) {
|
||||||
Ok(Some(events)) => Ok(Some(events)),
|
Ok(Some(events)) => Ok(Some(events)),
|
||||||
Ok(None) => Ok(None),
|
Ok(None) => Ok(None),
|
||||||
Err(e) => Err(e),
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => Err(ClientError::Sse(e.to_string())),
|
||||||
}
|
}
|
||||||
Err(e) => Err(ClientError::Sse(e.to_string())),
|
})
|
||||||
}
|
// flatten Option<Vec<Event>> stream to Stream<Event>
|
||||||
})
|
// map returns Result<Option<Vec<Event>>, Error>
|
||||||
// flatten Option<Vec<Event>> stream to Stream<Event>
|
// We want Stream<Item = Result<Event, Error>>
|
||||||
// map returns Result<Option<Vec<Event>>, Error>
|
.map(|res| {
|
||||||
// We want Stream<Item = Result<Event, Error>>
|
let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
|
||||||
.map(|res| {
|
Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
|
||||||
let s: Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>> = match res {
|
Ok(None) => Box::pin(futures::stream::empty()),
|
||||||
Ok(Some(events)) => Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
|
Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
|
||||||
Ok(None) => Box::pin(futures::stream::empty()),
|
};
|
||||||
Err(e) => Box::pin(futures::stream::once(async move { Err(e) })),
|
s
|
||||||
};
|
})
|
||||||
s
|
.flatten();
|
||||||
})
|
|
||||||
.flatten();
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,13 +127,12 @@ impl GeminiScheme {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response: GenerateContentResponse = serde_json::from_str(data).map_err(|e| {
|
let response: GenerateContentResponse =
|
||||||
ClientError::Api {
|
serde_json::from_str(data).map_err(|e| ClientError::Api {
|
||||||
status: None,
|
status: None,
|
||||||
code: Some("parse_error".to_string()),
|
code: Some("parse_error".to_string()),
|
||||||
message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data),
|
message: format!("Failed to parse Gemini SSE data: {} -> {}", e, data),
|
||||||
}
|
})?;
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
|
@ -155,10 +154,7 @@ impl GeminiScheme {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
// Geminiは明示的なBlockStartを送らないため、
|
// Geminiは明示的なBlockStartを送らないため、
|
||||||
// TextDeltaを直接送る(Timelineが暗黙的に開始を処理)
|
// TextDeltaを直接送る(Timelineが暗黙的に開始を処理)
|
||||||
events.push(Event::text_delta(
|
events.push(Event::text_delta(part_index, text.clone()));
|
||||||
part_index,
|
|
||||||
text.clone(),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -240,7 +236,8 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_text_response() {
|
fn test_parse_text_response() {
|
||||||
let scheme = GeminiScheme::new();
|
let scheme = GeminiScheme::new();
|
||||||
let data = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#;
|
let data =
|
||||||
|
r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0}]}"#;
|
||||||
|
|
||||||
let events = scheme.parse_event(data).unwrap().unwrap();
|
let events = scheme.parse_event(data).unwrap().unwrap();
|
||||||
assert_eq!(events.len(), 1);
|
assert_eq!(events.len(), 1);
|
||||||
|
|
|
||||||
|
|
@ -46,9 +46,7 @@ pub(crate) struct GeminiContent {
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum GeminiPart {
|
pub(crate) enum GeminiPart {
|
||||||
/// テキストパーツ
|
/// テキストパーツ
|
||||||
Text {
|
Text { text: String },
|
||||||
text: String,
|
|
||||||
},
|
|
||||||
/// 関数呼び出しパーツ
|
/// 関数呼び出しパーツ
|
||||||
FunctionCall {
|
FunctionCall {
|
||||||
#[serde(rename = "functionCall")]
|
#[serde(rename = "functionCall")]
|
||||||
|
|
@ -160,11 +158,7 @@ impl GeminiScheme {
|
||||||
vec![]
|
vec![]
|
||||||
} else {
|
} else {
|
||||||
vec![GeminiTool {
|
vec![GeminiTool {
|
||||||
function_declarations: request
|
function_declarations: request.tools.iter().map(|t| self.convert_tool(t)).collect(),
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.map(|t| self.convert_tool(t))
|
|
||||||
.collect(),
|
|
||||||
}]
|
}]
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -224,34 +218,30 @@ impl GeminiScheme {
|
||||||
},
|
},
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
MessageContent::Parts(parts) => {
|
MessageContent::Parts(parts) => parts
|
||||||
parts
|
.iter()
|
||||||
.iter()
|
.map(|p| match p {
|
||||||
.map(|p| match p {
|
ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
|
||||||
ContentPart::Text { text } => GeminiPart::Text { text: text.clone() },
|
ContentPart::ToolUse { id: _, name, input } => GeminiPart::FunctionCall {
|
||||||
ContentPart::ToolUse { id: _, name, input } => {
|
function_call: GeminiFunctionCall {
|
||||||
GeminiPart::FunctionCall {
|
name: name.clone(),
|
||||||
function_call: GeminiFunctionCall {
|
args: input.clone(),
|
||||||
name: name.clone(),
|
},
|
||||||
args: input.clone(),
|
},
|
||||||
},
|
ContentPart::ToolResult {
|
||||||
}
|
tool_use_id,
|
||||||
}
|
content,
|
||||||
ContentPart::ToolResult {
|
} => GeminiPart::FunctionResponse {
|
||||||
tool_use_id,
|
function_response: GeminiFunctionResponse {
|
||||||
content,
|
name: tool_use_id.clone(),
|
||||||
} => GeminiPart::FunctionResponse {
|
response: GeminiFunctionResponseContent {
|
||||||
function_response: GeminiFunctionResponse {
|
|
||||||
name: tool_use_id.clone(),
|
name: tool_use_id.clone(),
|
||||||
response: GeminiFunctionResponseContent {
|
content: serde_json::Value::String(content.clone()),
|
||||||
name: tool_use_id.clone(),
|
|
||||||
content: serde_json::Value::String(content.clone()),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
},
|
||||||
.collect()
|
})
|
||||||
}
|
.collect(),
|
||||||
};
|
};
|
||||||
|
|
||||||
GeminiContent {
|
GeminiContent {
|
||||||
|
|
@ -306,16 +296,17 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(gemini_req.tools.len(), 1);
|
assert_eq!(gemini_req.tools.len(), 1);
|
||||||
assert_eq!(gemini_req.tools[0].function_declarations.len(), 1);
|
assert_eq!(gemini_req.tools[0].function_declarations.len(), 1);
|
||||||
assert_eq!(gemini_req.tools[0].function_declarations[0].name, "get_weather");
|
assert_eq!(
|
||||||
|
gemini_req.tools[0].function_declarations[0].name,
|
||||||
|
"get_weather"
|
||||||
|
);
|
||||||
assert!(gemini_req.tool_config.is_some());
|
assert!(gemini_req.tool_config.is_some());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_assistant_role_is_model() {
|
fn test_assistant_role_is_model() {
|
||||||
let scheme = GeminiScheme::new();
|
let scheme = GeminiScheme::new();
|
||||||
let request = Request::new()
|
let request = Request::new().user("Hello").assistant("Hi there!");
|
||||||
.user("Hello")
|
|
||||||
.assistant("Hi there!");
|
|
||||||
|
|
||||||
let gemini_req = scheme.build_request(&request);
|
let gemini_req = scheme.build_request(&request);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,8 +69,8 @@ impl OpenAIScheme {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunk: ChatCompletionChunk = serde_json::from_str(data)
|
let chunk: ChatCompletionChunk =
|
||||||
.map_err(|e| ClientError::Api {
|
serde_json::from_str(data).map_err(|e| ClientError::Api {
|
||||||
status: None,
|
status: None,
|
||||||
code: Some("parse_error".to_string()),
|
code: Some("parse_error".to_string()),
|
||||||
message: format!("Failed to parse SSE data: {} -> {}", e, data),
|
message: format!("Failed to parse SSE data: {} -> {}", e, data),
|
||||||
|
|
@ -102,7 +102,11 @@ impl OpenAIScheme {
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
// Start of tool call (has ID)
|
// Start of tool call (has ID)
|
||||||
if let Some(id) = tool_call.id {
|
if let Some(id) = tool_call.id {
|
||||||
let name = tool_call.function.as_ref().and_then(|f| f.name.clone()).unwrap_or_default();
|
let name = tool_call
|
||||||
|
.function
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|f| f.name.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
events.push(Event::tool_use_start(tool_call.index, id, name));
|
events.push(Event::tool_use_start(tool_call.index, id, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -126,7 +130,8 @@ impl OpenAIScheme {
|
||||||
_ => Some(StopReason::EndTurn),
|
_ => Some(StopReason::EndTurn),
|
||||||
};
|
};
|
||||||
|
|
||||||
let is_tool_finish = finish_reason == "tool_calls" || finish_reason == "function_call";
|
let is_tool_finish =
|
||||||
|
finish_reason == "tool_calls" || finish_reason == "function_call";
|
||||||
|
|
||||||
if is_tool_finish {
|
if is_tool_finish {
|
||||||
// ツール呼び出し終了
|
// ツール呼び出し終了
|
||||||
|
|
|
||||||
|
|
@ -120,12 +120,7 @@ impl OpenAIScheme {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.extend(
|
messages.extend(request.messages.iter().map(|m| self.convert_message(m)));
|
||||||
request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.map(|m| self.convert_message(m))
|
|
||||||
);
|
|
||||||
|
|
||||||
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
|
||||||
|
|
||||||
|
|
@ -143,7 +138,9 @@ impl OpenAIScheme {
|
||||||
top_p: request.config.top_p,
|
top_p: request.config.top_p,
|
||||||
stop: request.config.stop_sequences.clone(),
|
stop: request.config.stop_sequences.clone(),
|
||||||
stream: true,
|
stream: true,
|
||||||
stream_options: Some(StreamOptions { include_usage: true }),
|
stream_options: Some(StreamOptions {
|
||||||
|
include_usage: true,
|
||||||
|
}),
|
||||||
messages,
|
messages,
|
||||||
tools,
|
tools,
|
||||||
tool_choice: None, // Default to auto if tools are present? Or let API decide (which is auto)
|
tool_choice: None, // Default to auto if tools are present? Or let API decide (which is auto)
|
||||||
|
|
@ -224,14 +221,14 @@ impl OpenAIScheme {
|
||||||
name: None,
|
name: None,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let content = if content_parts.is_empty() {
|
let content = if content_parts.is_empty() {
|
||||||
None
|
None
|
||||||
} else if content_parts.len() == 1 {
|
} else if content_parts.len() == 1 {
|
||||||
// Simplify single text part to just Text content if preferred, or keep as Parts
|
// Simplify single text part to just Text content if preferred, or keep as Parts
|
||||||
if let OpenAIContentPart::Text { text } = &content_parts[0] {
|
if let OpenAIContentPart::Text { text } = &content_parts[0] {
|
||||||
Some(OpenAIContent::Text(text.clone()))
|
Some(OpenAIContent::Text(text.clone()))
|
||||||
} else {
|
} else {
|
||||||
Some(OpenAIContent::Parts(content_parts))
|
Some(OpenAIContent::Parts(content_parts))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Some(OpenAIContent::Parts(content_parts))
|
Some(OpenAIContent::Parts(content_parts))
|
||||||
|
|
@ -265,13 +262,10 @@ impl OpenAIScheme {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_simple_request() {
|
fn test_build_simple_request() {
|
||||||
let scheme = OpenAIScheme::new();
|
let scheme = OpenAIScheme::new();
|
||||||
let request = Request::new()
|
let request = Request::new().system("System prompt").user("Hello");
|
||||||
.system("System prompt")
|
|
||||||
.user("Hello");
|
|
||||||
|
|
||||||
let body = scheme.build_request("gpt-4o", &request);
|
let body = scheme.build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
|
@ -303,9 +297,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_request_legacy_max_tokens() {
|
fn test_build_request_legacy_max_tokens() {
|
||||||
let scheme = OpenAIScheme::new().with_legacy_max_tokens(true);
|
let scheme = OpenAIScheme::new().with_legacy_max_tokens(true);
|
||||||
let request = Request::new()
|
let request = Request::new().user("Hello").max_tokens(100);
|
||||||
.user("Hello")
|
|
||||||
.max_tokens(100);
|
|
||||||
|
|
||||||
let body = scheme.build_request("llama3", &request);
|
let body = scheme.build_request("llama3", &request);
|
||||||
|
|
||||||
|
|
@ -317,9 +309,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_request_modern_max_tokens() {
|
fn test_build_request_modern_max_tokens() {
|
||||||
let scheme = OpenAIScheme::new(); // Default matches modern (legacy=false)
|
let scheme = OpenAIScheme::new(); // Default matches modern (legacy=false)
|
||||||
let request = Request::new()
|
let request = Request::new().user("Hello").max_tokens(100);
|
||||||
.user("Hello")
|
|
||||||
.max_tokens(100);
|
|
||||||
|
|
||||||
let body = scheme.build_request("gpt-4o", &request);
|
let body = scheme.build_request("gpt-4o", &request);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use tracing::{debug, info, trace, warn};
|
use tracing::{debug, info, trace, warn};
|
||||||
|
|
||||||
|
use crate::Timeline;
|
||||||
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
|
use crate::llm_client::{ClientError, LlmClient, Request, ToolDefinition};
|
||||||
use crate::subscriber_adapter::{
|
use crate::subscriber_adapter::{
|
||||||
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
|
||||||
|
|
@ -11,7 +12,6 @@ use crate::subscriber_adapter::{
|
||||||
};
|
};
|
||||||
use crate::text_block_collector::TextBlockCollector;
|
use crate::text_block_collector::TextBlockCollector;
|
||||||
use crate::tool_call_collector::ToolCallCollector;
|
use crate::tool_call_collector::ToolCallCollector;
|
||||||
use crate::Timeline;
|
|
||||||
use worker_types::{
|
use worker_types::{
|
||||||
ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError,
|
ContentPart, ControlFlow, HookError, Message, MessageContent, Tool, ToolCall, ToolError,
|
||||||
ToolResult, TurnResult, WorkerHook, WorkerSubscriber,
|
ToolResult, TurnResult, WorkerHook, WorkerSubscriber,
|
||||||
|
|
@ -442,10 +442,7 @@ impl<C: LlmClient> Worker<C> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hooks: on_turn_end
|
/// Hooks: on_turn_end
|
||||||
async fn run_on_turn_end_hooks(
|
async fn run_on_turn_end_hooks(&self, messages: &[Message]) -> Result<TurnResult, WorkerError> {
|
||||||
&self,
|
|
||||||
messages: &[Message],
|
|
||||||
) -> Result<TurnResult, WorkerError> {
|
|
||||||
for hook in &self.hooks {
|
for hook in &self.hooks {
|
||||||
let result = hook.on_turn_end(messages).await?;
|
let result = hook.on_turn_end(messages).await?;
|
||||||
match result {
|
match result {
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,13 @@
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{BufRead, BufReader};
|
use std::io::{BufRead, BufReader};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline};
|
|
||||||
use worker::llm_client::{ClientError, LlmClient, Request};
|
use worker::llm_client::{ClientError, LlmClient, Request};
|
||||||
|
use worker::{Handler, TextBlockEvent, TextBlockKind, Timeline};
|
||||||
use worker_types::{BlockType, DeltaContent, Event};
|
use worker_types::{BlockType, DeltaContent, Event};
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
@ -51,11 +51,11 @@ impl LlmClient for MockLlmClient {
|
||||||
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, ClientError>> + Send>>, ClientError> {
|
||||||
let count = self.call_count.fetch_add(1, Ordering::SeqCst);
|
let count = self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||||
if count >= self.responses.len() {
|
if count >= self.responses.len() {
|
||||||
return Err(ClientError::Api {
|
return Err(ClientError::Api {
|
||||||
status: Some(500),
|
status: Some(500),
|
||||||
code: Some("mock_error".to_string()),
|
code: Some("mock_error".to_string()),
|
||||||
message: "No more mock responses".to_string(),
|
message: "No more mock responses".to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
let events = self.responses[count].clone();
|
let events = self.responses[count].clone();
|
||||||
let stream = futures::stream::iter(events.into_iter().map(Ok));
|
let stream = futures::stream::iter(events.into_iter().map(Ok));
|
||||||
|
|
@ -135,7 +135,8 @@ pub fn assert_event_sequence(subdir: &str) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find a text-based fixture
|
// Find a text-based fixture
|
||||||
let fixture_path = fixtures.iter()
|
let fixture_path = fixtures
|
||||||
|
.iter()
|
||||||
.find(|p| p.to_string_lossy().contains("text"))
|
.find(|p| p.to_string_lossy().contains("text"))
|
||||||
.unwrap_or(&fixtures[0]);
|
.unwrap_or(&fixtures[0]);
|
||||||
|
|
||||||
|
|
@ -156,9 +157,9 @@ pub fn assert_event_sequence(subdir: &str) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Event::BlockDelta(delta) => {
|
Event::BlockDelta(delta) => {
|
||||||
if let DeltaContent::Text(_) = &delta.delta {
|
if let DeltaContent::Text(_) = &delta.delta {
|
||||||
delta_found = true;
|
delta_found = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Event::BlockStop(stop) => {
|
Event::BlockStop(stop) => {
|
||||||
if stop.block_type == BlockType::Text {
|
if stop.block_type == BlockType::Text {
|
||||||
|
|
@ -173,9 +174,9 @@ pub fn assert_event_sequence(subdir: &str) {
|
||||||
|
|
||||||
// Check for BlockStart (Warn only for OpenAI/Ollama as it might be missing for text)
|
// Check for BlockStart (Warn only for OpenAI/Ollama as it might be missing for text)
|
||||||
if !start_found {
|
if !start_found {
|
||||||
println!("Warning: No BlockStart found. This is common for OpenAI/Ollama text streams.");
|
println!("Warning: No BlockStart found. This is common for OpenAI/Ollama text streams.");
|
||||||
// For Anthropic, strict start is usually expected, but to keep common logic simple we allow warning.
|
// For Anthropic, strict start is usually expected, but to keep common logic simple we allow warning.
|
||||||
// If specific strictness is needed, we could add a `strict: bool` arg.
|
// If specific strictness is needed, we could add a `strict: bool` arg.
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(delta_found, "Should contain BlockDelta");
|
assert!(delta_found, "Should contain BlockDelta");
|
||||||
|
|
@ -184,7 +185,9 @@ pub fn assert_event_sequence(subdir: &str) {
|
||||||
assert!(stop_found, "Should contain BlockStop for Text block");
|
assert!(stop_found, "Should contain BlockStop for Text block");
|
||||||
} else {
|
} else {
|
||||||
if !stop_found {
|
if !stop_found {
|
||||||
println!(" [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)");
|
println!(
|
||||||
|
" [Type: ToolUse] BlockStop detection skipped (not explicitly emitted by scheme)"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -200,13 +203,23 @@ pub fn assert_usage_tokens(subdir: &str) {
|
||||||
let events = load_events_from_fixture(&fixture);
|
let events = load_events_from_fixture(&fixture);
|
||||||
let usage_events: Vec<_> = events
|
let usage_events: Vec<_> = events
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|e| if let Event::Usage(u) = e { Some(u) } else { None })
|
.filter_map(|e| {
|
||||||
|
if let Event::Usage(u) = e {
|
||||||
|
Some(u)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if !usage_events.is_empty() {
|
if !usage_events.is_empty() {
|
||||||
let last_usage = usage_events.last().unwrap();
|
let last_usage = usage_events.last().unwrap();
|
||||||
if last_usage.input_tokens.is_some() || last_usage.output_tokens.is_some() {
|
if last_usage.input_tokens.is_some() || last_usage.output_tokens.is_some() {
|
||||||
println!(" Fixture {:?} Usage: {:?}", fixture.file_name(), last_usage);
|
println!(
|
||||||
|
" Fixture {:?} Usage: {:?}",
|
||||||
|
fixture.file_name(),
|
||||||
|
last_usage
|
||||||
|
);
|
||||||
return; // Found valid usage
|
return; // Found valid usage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -221,7 +234,8 @@ pub fn assert_timeline_integration(subdir: &str) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let fixture_path = fixtures.iter()
|
let fixture_path = fixtures
|
||||||
|
.iter()
|
||||||
.find(|p| p.to_string_lossy().contains("text"))
|
.find(|p| p.to_string_lossy().contains("text"))
|
||||||
.unwrap_or(&fixtures[0]);
|
.unwrap_or(&fixtures[0]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,16 @@
|
||||||
//!
|
//!
|
||||||
//! Workerが複数のツールを並列に実行することを確認する。
|
//! Workerが複数のツールを並列に実行することを確認する。
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use worker::Worker;
|
use worker::Worker;
|
||||||
use worker_types::{Event, Message, ResponseStatus, StatusEvent, Tool, ToolError, ToolResult, ToolCall, ControlFlow, HookError, WorkerHook};
|
use worker_types::{
|
||||||
|
ControlFlow, Event, HookError, Message, ResponseStatus, StatusEvent, Tool, ToolCall, ToolError,
|
||||||
|
ToolResult, WorkerHook,
|
||||||
|
};
|
||||||
|
|
||||||
mod common;
|
mod common;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
|
|
@ -105,8 +108,6 @@ async fn test_parallel_tool_execution() {
|
||||||
worker.register_tool(tool2);
|
worker.register_tool(tool2);
|
||||||
worker.register_tool(tool3);
|
worker.register_tool(tool3);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
let messages = vec![Message::user("Run all tools")];
|
let messages = vec![Message::user("Run all tools")];
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
@ -161,7 +162,10 @@ async fn test_before_tool_call_skip() {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for BlockingHook {
|
impl WorkerHook for BlockingHook {
|
||||||
async fn before_tool_call(&self, tool_call: &mut ToolCall) -> Result<ControlFlow, HookError> {
|
async fn before_tool_call(
|
||||||
|
&self,
|
||||||
|
tool_call: &mut ToolCall,
|
||||||
|
) -> Result<ControlFlow, HookError> {
|
||||||
if tool_call.name == "blocked_tool" {
|
if tool_call.name == "blocked_tool" {
|
||||||
Ok(ControlFlow::Skip)
|
Ok(ControlFlow::Skip)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -176,8 +180,16 @@ async fn test_before_tool_call_skip() {
|
||||||
let _result = worker.run(messages).await;
|
let _result = worker.run(messages).await;
|
||||||
|
|
||||||
// allowed_tool は呼び出されるが、blocked_tool は呼び出されない
|
// allowed_tool は呼び出されるが、blocked_tool は呼び出されない
|
||||||
assert_eq!(allowed_clone.call_count(), 1, "Allowed tool should be called");
|
assert_eq!(
|
||||||
assert_eq!(blocked_clone.call_count(), 0, "Blocked tool should not be called");
|
allowed_clone.call_count(),
|
||||||
|
1,
|
||||||
|
"Allowed tool should be called"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
blocked_clone.call_count(),
|
||||||
|
0,
|
||||||
|
"Blocked tool should not be called"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Hook: after_tool_call で結果が改変されることを確認
|
/// Hook: after_tool_call で結果が改変されることを確認
|
||||||
|
|
@ -212,9 +224,15 @@ async fn test_after_tool_call_modification() {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for SimpleTool {
|
impl Tool for SimpleTool {
|
||||||
fn name(&self) -> &str { "test_tool" }
|
fn name(&self) -> &str {
|
||||||
fn description(&self) -> &str { "Test" }
|
"test_tool"
|
||||||
fn input_schema(&self) -> serde_json::Value { serde_json::json!({}) }
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Test"
|
||||||
|
}
|
||||||
|
fn input_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({})
|
||||||
|
}
|
||||||
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
async fn execute(&self, _: &str) -> Result<String, ToolError> {
|
||||||
Ok("Original Result".to_string())
|
Ok("Original Result".to_string())
|
||||||
}
|
}
|
||||||
|
|
@ -229,7 +247,10 @@ async fn test_after_tool_call_modification() {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl WorkerHook for ModifyingHook {
|
impl WorkerHook for ModifyingHook {
|
||||||
async fn after_tool_call(&self, tool_result: &mut ToolResult) -> Result<ControlFlow, HookError> {
|
async fn after_tool_call(
|
||||||
|
&self,
|
||||||
|
tool_result: &mut ToolResult,
|
||||||
|
) -> Result<ControlFlow, HookError> {
|
||||||
tool_result.content = format!("[Modified] {}", tool_result.content);
|
tool_result.content = format!("[Modified] {}", tool_result.content);
|
||||||
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
*self.modified_content.lock().unwrap() = Some(tool_result.content.clone());
|
||||||
Ok(ControlFlow::Continue)
|
Ok(ControlFlow::Continue)
|
||||||
|
|
@ -237,7 +258,9 @@ async fn test_after_tool_call_modification() {
|
||||||
}
|
}
|
||||||
|
|
||||||
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
let modified_content = Arc::new(std::sync::Mutex::new(None));
|
||||||
worker.add_hook(ModifyingHook { modified_content: modified_content.clone() });
|
worker.add_hook(ModifyingHook {
|
||||||
|
modified_content: modified_content.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
let messages = vec![Message::user("Test modification")];
|
let messages = vec![Message::user("Test modification")];
|
||||||
let result = worker.run(messages).await;
|
let result = worker.run(messages).await;
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@
|
||||||
//!
|
//!
|
||||||
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
|
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
// マクロ展開に必要なインポート
|
// マクロ展開に必要なインポート
|
||||||
use schemars;
|
use schemars;
|
||||||
|
|
@ -59,12 +59,19 @@ async fn test_basic_tool_generation() {
|
||||||
|
|
||||||
// 説明の確認(docコメントから取得)
|
// 説明の確認(docコメントから取得)
|
||||||
let desc = greet_tool.description();
|
let desc = greet_tool.description();
|
||||||
assert!(desc.contains("メッセージに挨拶を追加する"), "Description should contain doc comment: {}", desc);
|
assert!(
|
||||||
|
desc.contains("メッセージに挨拶を追加する"),
|
||||||
|
"Description should contain doc comment: {}",
|
||||||
|
desc
|
||||||
|
);
|
||||||
|
|
||||||
// スキーマの確認
|
// スキーマの確認
|
||||||
let schema = greet_tool.input_schema();
|
let schema = greet_tool.input_schema();
|
||||||
println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
|
println!("Schema: {}", serde_json::to_string_pretty(&schema).unwrap());
|
||||||
assert!(schema.get("properties").is_some(), "Schema should have properties");
|
assert!(
|
||||||
|
schema.get("properties").is_some(),
|
||||||
|
"Schema should have properties"
|
||||||
|
);
|
||||||
|
|
||||||
// 実行テスト
|
// 実行テスト
|
||||||
let result = greet_tool.execute(r#"{"message": "World"}"#).await;
|
let result = greet_tool.execute(r#"{"message": "World"}"#).await;
|
||||||
|
|
@ -104,7 +111,11 @@ async fn test_no_arguments() {
|
||||||
let result = get_prefix_tool.execute(r#"{}"#).await;
|
let result = get_prefix_tool.execute(r#"{}"#).await;
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let output = result.unwrap();
|
let output = result.unwrap();
|
||||||
assert!(output.contains("TestPrefix"), "Should contain prefix: {}", output);
|
assert!(
|
||||||
|
output.contains("TestPrefix"),
|
||||||
|
"Should contain prefix: {}",
|
||||||
|
output
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
@ -169,7 +180,11 @@ async fn test_result_return_type_error() {
|
||||||
assert!(result.is_err(), "Should fail for negative value");
|
assert!(result.is_err(), "Should fail for negative value");
|
||||||
|
|
||||||
let err = result.unwrap_err();
|
let err = result.unwrap_err();
|
||||||
assert!(err.to_string().contains("positive"), "Error should mention positive: {}", err);
|
assert!(
|
||||||
|
err.to_string().contains("positive"),
|
||||||
|
"Error should mention positive: {}",
|
||||||
|
err
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@
|
||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use common::MockLlmClient;
|
use common::MockLlmClient;
|
||||||
|
|
@ -67,9 +67,7 @@ impl Tool for MockWeatherTool {
|
||||||
let input: serde_json::Value = serde_json::from_str(input_json)
|
let input: serde_json::Value = serde_json::from_str(input_json)
|
||||||
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
|
.map_err(|e| ToolError::InvalidArgument(e.to_string()))?;
|
||||||
|
|
||||||
let city = input["city"]
|
let city = input["city"].as_str().unwrap_or("Unknown");
|
||||||
.as_str()
|
|
||||||
.unwrap_or("Unknown");
|
|
||||||
|
|
||||||
// モックのレスポンスを返す
|
// モックのレスポンスを返す
|
||||||
Ok(format!("Weather in {}: Sunny, 22°C", city))
|
Ok(format!("Weather in {}: Sunny, 22°C", city))
|
||||||
|
|
@ -163,8 +161,6 @@ async fn test_worker_tool_call() {
|
||||||
let tool_for_check = weather_tool.clone();
|
let tool_for_check = weather_tool.clone();
|
||||||
worker.register_tool(weather_tool);
|
worker.register_tool(weather_tool);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// メッセージを送信
|
// メッセージを送信
|
||||||
let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];
|
let messages = vec![worker_types::Message::user("What's the weather in Tokyo?")];
|
||||||
let _result = worker.run(messages).await;
|
let _result = worker.run(messages).await;
|
||||||
|
|
@ -212,8 +208,8 @@ async fn test_worker_with_programmatic_events() {
|
||||||
/// id, name, input(JSON)を正しく抽出できることを検証する。
|
/// id, name, input(JSON)を正しく抽出できることを検証する。
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tool_call_collector_integration() {
|
async fn test_tool_call_collector_integration() {
|
||||||
use worker::ToolCallCollector;
|
|
||||||
use worker::Timeline;
|
use worker::Timeline;
|
||||||
|
use worker::ToolCallCollector;
|
||||||
use worker_types::Event;
|
use worker_types::Event;
|
||||||
|
|
||||||
// ToolUseブロックを含むイベントシーケンス
|
// ToolUseブロックを含むイベントシーケンス
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user