use crate::core::LlmClientTrait;
use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing;

/// Extract tool name from Tool message content
fn extract_tool_name_from_content(content: &str) -> Option<String> {
    // Look for patterns like "Tool 'tool_name' executed successfully"
    if let Some(start) = content.find("Tool '") {
        if let Some(end) = content[start + 6..].find("'") {
            let tool_name = &content[start + 6..start + 6 + end];
            return Some(tool_name.to_string());
        }
    }
    None
}
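// A minimal unit-test sketch for `extract_tool_name_from_content`, assuming the
// "Tool '<name>' executed successfully" phrasing referenced in the comment above is the
// format produced elsewhere in this crate. The sample strings are illustrative, not
// captured tool output.
#[cfg(test)]
mod extract_tool_name_tests {
    use super::*;

    #[test]
    fn extracts_name_from_quoted_pattern() {
        let content = "Tool 'search_files' executed successfully\nResult: {\"count\": 3}";
        assert_eq!(
            extract_tool_name_from_content(content),
            Some("search_files".to_string())
        );
    }

    #[test]
    fn returns_none_when_pattern_is_absent() {
        assert_eq!(extract_tool_name_from_content("plain text output"), None);
    }
}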
/// Transforms a JSON schema to be compatible with the Gemini API.
/// Converts 'uint' types to 'integer' types and handles nullable types.
/// Also ensures the schema is in the correct format for Gemini function parameters.
fn transform_schema_for_gemini(schema: serde_json::Value) -> serde_json::Value {
    match schema {
        serde_json::Value::Object(mut obj) => {
            // Remove $schema key as it's not needed for Gemini
            obj.remove("$schema");

            // Handle type field
            if let Some(type_val) = obj.get("type") {
                match type_val {
                    // Convert 'uint' to 'integer'
                    serde_json::Value::String(s) if s == "uint" => {
                        obj.insert(
                            "type".to_string(),
                            serde_json::Value::String("integer".to_string()),
                        );
                        // Add format for integer types as required by Gemini
                        obj.insert(
                            "format".to_string(),
                            serde_json::Value::String("int64".to_string()),
                        );
                    }
                    // Handle array types like ["integer", "null"]
                    serde_json::Value::Array(arr) => {
                        if let Some(non_null_type) = arr.iter().find(|&t| t != "null") {
                            // Use the non-null type
                            let mut new_type = non_null_type.clone();
                            // Convert 'uint' to 'integer' if needed
                            if let serde_json::Value::String(s) = &new_type {
                                if s == "uint" {
                                    new_type = serde_json::Value::String("integer".to_string());
                                }
                            }
                            obj.insert("type".to_string(), new_type.clone());
                            // Add format for integer types as required by Gemini
                            if let serde_json::Value::String(type_str) = &new_type {
                                if type_str == "integer" {
                                    obj.insert(
                                        "format".to_string(),
                                        serde_json::Value::String("int64".to_string()),
                                    );
                                }
                            }
                        }
                    }
                    // Handle existing integer types
                    serde_json::Value::String(s) if s == "integer" => {
                        // Add format for integer types as required by Gemini
                        obj.insert(
                            "format".to_string(),
                            serde_json::Value::String("int64".to_string()),
                        );
                    }
                    _ => {}
                }
            }

            // Handle properties and required fields
            if let (Some(properties), Some(required)) =
                (obj.get("properties"), obj.get("required"))
            {
                if let (serde_json::Value::Object(props), serde_json::Value::Array(req_arr)) =
                    (properties, required)
                {
                    let mut new_required = Vec::new();
                    for (prop_name, _) in props {
                        // Only include in required if it's not nullable
                        if req_arr.iter().any(|r| r == prop_name) {
                            // Check if this property has a nullable type
                            if let Some(prop_schema) = props.get(prop_name) {
                                if let Some(prop_type) = prop_schema.get("type") {
                                    // If type is an array containing "null", it's nullable
                                    let is_nullable = match prop_type {
                                        serde_json::Value::Array(arr) => {
                                            arr.iter().any(|t| t == "null")
                                        }
                                        _ => false,
                                    };
                                    // Only add to required if not nullable
                                    if !is_nullable {
                                        new_required
                                            .push(serde_json::Value::String(prop_name.clone()));
                                    }
                                } else {
                                    // No type info, assume required
                                    new_required
                                        .push(serde_json::Value::String(prop_name.clone()));
                                }
                            }
                        }
                    }
                    obj.insert(
                        "required".to_string(),
                        serde_json::Value::Array(new_required),
                    );
                }
            }

            // Recursively transform nested objects
            for (_, value) in obj.iter_mut() {
                *value = transform_schema_for_gemini(value.clone());
            }

            serde_json::Value::Object(obj)
        }
        serde_json::Value::Array(arr) => {
            serde_json::Value::Array(arr.into_iter().map(transform_schema_for_gemini).collect())
        }
        other => other,
    }
}

// --- Request Structures ---

#[derive(Debug, Serialize, Clone)]
pub struct GeminiTool {
    #[serde(rename = "functionDeclarations")]
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Debug, Serialize, Clone)]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Serialize, Clone)]
pub struct GeminiRequest {
    pub contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "systemInstruction")]
    pub system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiContent {
    pub role: String,
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    Text {
        text: String,
    },
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

// --- Response Structures ---

#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    #[serde(default)]
    pub candidates: Vec<GeminiCandidate>,
}

#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    pub content: GeminiContent,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}

fn build_url(model: &str) -> String {
    let base_url = UrlConfig::get_base_url("gemini");
    let action = "streamGenerateContent";
    format!("{}/v1beta/models/{}:{}", base_url, model, action)
}

/// Finds the start and end indices of the first complete JSON object `{...}` in the buffer.
fn find_first_json_object_bounds(buffer: &[u8]) -> Option<(usize, usize)> {
    if let Some(start) = buffer.iter().position(|&b| b == b'{') {
        let mut brace_count = 0;
        let mut in_string = false;
        let mut escaped = false;
        for (i, &byte) in buffer.iter().skip(start).enumerate() {
            if in_string {
                if escaped {
                    escaped = false;
                } else if byte == b'\\' {
                    escaped = true;
                } else if byte == b'"' {
                    in_string = false;
                }
            } else {
                match byte {
                    b'"' => in_string = true,
                    b'{' => brace_count += 1,
                    b'}' => {
                        brace_count -= 1;
                        if brace_count == 0 {
                            let end = start + i + 1;
                            return Some((start, end));
                        }
                    }
                    _ => {}
                }
            }
        }
    }
    None // No complete object found
}
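// A minimal unit-test sketch for `find_first_json_object_bounds`, illustrating how the
// streaming loop below frames complete JSON objects out of a partially received byte
// buffer. The byte strings are made up for illustration; they only need to exercise the
// brace/string tracking, not match real Gemini output.
#[cfg(test)]
mod json_framing_tests {
    use super::*;

    #[test]
    fn returns_bounds_of_first_complete_object() {
        // Leading noise before the first '{' is skipped; nested braces and braces inside
        // strings must not terminate the object early.
        let buffer = br#"[{"a": {"b": "close } brace"}, "c": 1}, {"unfinished"#;
        let (start, end) =
            find_first_json_object_bounds(buffer).expect("object should be found");
        let expected = br#"{"a": {"b": "close } brace"}, "c": 1}"#;
        assert_eq!(&buffer[start..end], &expected[..]);
    }

    #[test]
    fn returns_none_for_incomplete_object() {
        assert_eq!(find_first_json_object_bounds(br#"[{"a": {"b": 1}"#), None);
    }
}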
/// Completes a chat request with streaming, yielding StreamEvent objects.
// NOTE: `crate::types::LlmDebugSettings` is assumed here (and below) to be the crate's
// optional debug-hook type exposing `debug_request` / `debug_response`; adjust the path
// if the actual type differs.
pub(crate) fn stream_events<'a>(
    api_key: &'a str,
    model: &'a str,
    request: GeminiRequest,
    llm_debug: Option<crate::types::LlmDebugSettings>,
) -> impl Stream<Item = Result<StreamEvent, anyhow::Error>> + 'a {
    let api_key = api_key.to_string();
    let model = model.to_string();
    async_stream::try_stream! {
        let body = serde_json::to_string_pretty(&request).unwrap_or_else(|e| e.to_string());
        tracing::debug!("Gemini Request Body: {}", body);

        if let Some(debug_settings) = &llm_debug {
            if let Some(debug_event) = debug_settings.debug_request(
                &model,
                "Gemini",
                &serde_json::to_value(&request).unwrap_or_default(),
            ) {
                yield debug_event;
            }
        }

        let client = Client::new();
        let url = build_url(&model);
        let response = client
            .post(&url)
            .header("x-goog-api-key", &api_key)
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Gemini API request failed: {}", e))?;

        let status = response.status();
        if !status.is_success() {
            let error_body = response
                .text()
                .await
                .unwrap_or_else(|_| "Could not read error body".to_string());
            let error_msg = format!(
                "Gemini API request failed with status: {} - {}",
                status, error_body
            );
            tracing::error!("{}", error_msg);
            Err(anyhow::anyhow!(error_msg))?;
        } else {
            let mut byte_stream = response.bytes_stream();
            let mut buffer = Vec::new();
            let mut full_content = String::new();

            while let Some(chunk_result) = byte_stream.next().await {
                let chunk = chunk_result?;
                buffer.extend_from_slice(&chunk);

                while let Some((start, end)) = find_first_json_object_bounds(&buffer) {
                    let json_slice = &buffer[start..end];

                    if let Some(debug_settings) = &llm_debug {
                        if let Ok(response_value) =
                            serde_json::from_slice::<serde_json::Value>(json_slice)
                        {
                            if let Some(debug_event) =
                                debug_settings.debug_response(&model, "Gemini", &response_value)
                            {
                                yield debug_event;
                            }
                        }
                    }

                    match serde_json::from_slice::<GeminiResponse>(json_slice) {
                        Ok(response) => {
                            let response_text = String::from_utf8_lossy(json_slice);
                            tracing::debug!(
                                response = %response_text,
                                candidates_count = response.candidates.len(),
                                "Successfully parsed Gemini response"
                            );

                            if response.candidates.is_empty() {
                                tracing::warn!(
                                    response = %response_text,
                                    "Received empty candidates in Gemini response"
                                );
                            } else if let Some(candidate) = response.candidates.get(0) {
                                // Log finish reason for debugging
                                if let Some(ref finish_reason) = candidate.finish_reason {
                                    tracing::debug!(
                                        finish_reason = %finish_reason,
                                        "Received finish reason in Gemini response"
                                    );
                                    // Handle specific finish reasons
                                    match finish_reason.as_str() {
                                        "STOP" => {
                                            tracing::debug!("Gemini response completed with STOP");
                                            // Continue processing parts if any; this is normal completion
                                        }
                                        "MAX_TOKENS" => {
                                            tracing::warn!("Gemini response stopped due to MAX_TOKENS");
                                        }
                                        "SAFETY" => {
                                            tracing::warn!("Gemini response stopped due to SAFETY concerns");
                                        }
                                        "RECITATION" => {
                                            tracing::warn!("Gemini response stopped due to RECITATION");
                                        }
                                        other => {
                                            tracing::warn!(
                                                "Gemini response stopped with unknown reason: {}",
                                                other
                                            );
                                        }
                                    }
                                }

                                if candidate.content.parts.is_empty() {
                                    tracing::warn!(
                                        response = %response_text,
                                        role = %candidate.content.role,
                                        finish_reason = ?candidate.finish_reason,
                                        "Received empty parts in Gemini response"
                                    );
                                } else {
                                    for part in &candidate.content.parts {
                                        tracing::debug!("Processing Gemini part");
                                        match part {
                                            GeminiPart::Text { text } => {
                                                tracing::debug!(
                                                    "Found Text part with content length: {}",
                                                    text.len()
                                                );
                                                full_content.push_str(text);
                                                yield StreamEvent::Chunk(text.clone());
                                            }
                                            GeminiPart::FunctionCall { function_call } => {
                                                tracing::debug!(
                                                    "Found FunctionCall part: name={}, args={:?}",
                                                    function_call.name,
                                                    function_call.args
                                                );
                                                let tool_call = ToolCall {
                                                    name: function_call.name.clone(),
                                                    arguments: serde_json::to_string(&function_call.args)
                                                        .unwrap_or_else(|_| "{}".to_string()),
                                                };
                                                yield StreamEvent::ToolCall(tool_call.clone());
                                            }
                                            GeminiPart::FunctionResponse { .. } => {
                                                // Function responses in model output are not expected,
                                                // as they're part of the input conversation history
                                                tracing::warn!("Unexpected FunctionResponse in model output");
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let response_text = String::from_utf8_lossy(json_slice);
                            tracing::warn!(
                                error = %e,
                                response = %response_text,
                                "Failed to deserialize GeminiResponse from slice"
                            );
                        }
                    }

                    buffer.drain(..end);
                }
            }

            let final_message = Message::new(Role::Model, full_content);
            yield StreamEvent::Completion(final_message);
        }
    }
}
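// A small deserialization sketch showing the response shape `stream_events` expects from a
// framed chunk. The JSON literal is illustrative, modelled on the structs above rather than
// taken from a captured Gemini payload.
#[cfg(test)]
mod response_parsing_tests {
    use super::*;

    #[test]
    fn parses_text_and_function_call_parts() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        {"text": "Looking that up."},
                        {"functionCall": {"name": "get_weather", "args": {"city": "Tokyo"}}}
                    ]
                },
                "finishReason": "STOP"
            }]
        }"#;

        let response: GeminiResponse = serde_json::from_str(json).expect("should deserialize");
        let candidate = &response.candidates[0];
        assert_eq!(candidate.finish_reason.as_deref(), Some("STOP"));
        assert_eq!(candidate.content.role, "model");
        assert_eq!(candidate.content.parts.len(), 2);
        match &candidate.content.parts[1] {
            GeminiPart::FunctionCall { function_call } => {
                assert_eq!(function_call.name, "get_weather");
                assert_eq!(function_call.args["city"], "Tokyo");
            }
            other => panic!("expected a FunctionCall part, got {:?}", other),
        }
    }
}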
pub struct GeminiClient {
    api_key: String,
    model: String,
}

impl GeminiClient {
    pub fn new(api_key: &str, model: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
            model: model.to_string(),
        }
    }

    pub fn get_model_name(&self) -> String {
        self.model.clone()
    }

    /// Static method: fetches the list of available models using the given API key.
    pub async fn list_models_static(
        api_key: &str,
    ) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_models_url("gemini");

        let response = client
            .get(url)
            .header("x-goog-api-key", api_key)
            .send()
            .await
            .map_err(|e| {
                tracing::error!("Gemini API request failed: {}", e);
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            tracing::error!(
                "Gemini list_models_static failed - Status: {}, Body: {}",
                status,
                error_body
            );
            return Err(WorkerError::from_api_error(
                format!("Failed to list Gemini models: {} - {}", status, error_body),
                &crate::types::LlmProvider::Gemini,
            ));
        }

        let models_response: serde_json::Value = response.json().await.map_err(|e| {
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
        })?;

        let mut models = Vec::new();
        if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
            for model in models_array {
                if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
                    // Strip the "models/" prefix
                    let model_id = name.strip_prefix("models/").unwrap_or(name);

                    // Only include models that support the generateContent method
                    if let Some(supported_methods) = model
                        .get("supportedGenerationMethods")
                        .and_then(|m| m.as_array())
                    {
                        let supports_generate_content = supported_methods
                            .iter()
                            .any(|method| method.as_str() == Some("generateContent"));

                        if supports_generate_content {
                            models.push(crate::types::ModelInfo {
                                id: model_id.to_string(),
                                name: model
                                    .get("displayName")
                                    .and_then(|d| d.as_str())
                                    .unwrap_or(model_id)
                                    .to_string(),
                                provider: crate::types::LlmProvider::Gemini,
                                supports_tools: true,
                                supports_function_calling: true,
                                supports_vision: false,
                                supports_multimodal: false,
                                context_length: None,
                                training_cutoff: None,
                                capabilities: vec!["text_generation".to_string()],
                                description: model
                                    .get("description")
                                    .and_then(|d| d.as_str())
                                    .map(|s| s.to_string())
                                    .or_else(|| {
                                        Some(format!("Google Gemini model: {}", model_id))
                                    }),
                            });
                        }
                    }
                }
            }
        }

        tracing::info!(
            "Gemini list_models_static found {} models with metadata",
            models.len()
        );
        Ok(models)
    }

    pub async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebugSettings>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        // Separate system messages from regular messages
        let (system_messages, regular_messages): (Vec<_>, Vec<_>) = messages
            .into_iter()
            .partition(|msg| matches!(msg.role, Role::System));

        // Create system instruction from system messages
        let system_instruction = if !system_messages.is_empty() {
            let combined_system_content = system_messages
                .into_iter()
                .map(|msg| msg.content)
                .collect::<Vec<_>>()
                .join("\n\n");

            Some(GeminiContent {
                role: "user".to_string(), // System instruction uses "user" role
                parts: vec![GeminiPart::Text {
                    text: combined_system_content,
                }],
            })
        } else {
            None
        };
        // Process regular messages with proper tool context handling
        let contents = regular_messages
            .into_iter()
            .map(|msg| {
                let (role, parts) = match msg.role {
                    Role::User => (
                        "user".to_string(),
                        vec![GeminiPart::Text { text: msg.content }],
                    ),
                    Role::Model => {
                        if let Some(tool_calls) = &msg.tool_calls {
                            // Model message with tool calls - convert to FunctionCall parts
                            tracing::debug!(
                                "Converting model message with {} tool calls to FunctionCall parts",
                                tool_calls.len()
                            );
                            let mut parts = Vec::new();

                            // Add text content if present
                            if !msg.content.is_empty() {
                                parts.push(GeminiPart::Text {
                                    text: msg.content.clone(),
                                });
                            }

                            // Add function calls
                            for tool_call in tool_calls {
                                tracing::debug!(
                                    "Adding FunctionCall part for tool: {}",
                                    tool_call.name
                                );
                                let args = serde_json::from_str(&tool_call.arguments)
                                    .unwrap_or(serde_json::json!({}));
                                parts.push(GeminiPart::FunctionCall {
                                    function_call: GeminiFunctionCall {
                                        name: tool_call.name.clone(),
                                        args,
                                    },
                                });
                            }

                            ("model".to_string(), parts)
                        } else {
                            // Regular model message
                            tracing::debug!("Converting regular model message (no tool calls)");
                            (
                                "model".to_string(),
                                vec![GeminiPart::Text { text: msg.content }],
                            )
                        }
                    }
                    Role::Tool => {
                        // Tool responses should be sent as FunctionResponse
                        if let Some(tool_name) = extract_tool_name_from_content(&msg.content) {
                            // Extract result from the content
                            let result_value = if msg.content.contains("Result: ") {
                                if let Some(result_start) = msg.content.find("Result: ") {
                                    let result_str = &msg.content[result_start + 8..];
                                    // Try to parse as JSON, fallback to string
                                    serde_json::from_str(result_str)
                                        .unwrap_or_else(|_| serde_json::json!(result_str))
                                } else {
                                    serde_json::json!(msg.content)
                                }
                            } else {
                                serde_json::json!(msg.content)
                            };

                            (
                                "user".to_string(),
                                vec![GeminiPart::FunctionResponse {
                                    function_response: GeminiFunctionResponse {
                                        name: tool_name,
                                        response: result_value,
                                    },
                                }],
                            )
                        } else {
                            // Fallback to text response if tool name can't be extracted
                            (
                                "user".to_string(),
                                vec![GeminiPart::Text {
                                    text: format!("Tool Response:\n{}", msg.content),
                                }],
                            )
                        }
                    }
                    Role::System => unreachable!(), // Should not reach here after partition
                };

                GeminiContent { role, parts }
            })
            .collect();

        let tools = tools.map(|tools| {
            vec![GeminiTool {
                function_declarations: tools
                    .iter()
                    .map(|tool| {
                        let mut transformed_schema =
                            transform_schema_for_gemini(tool.parameters_schema.clone());

                        // Ensure the schema has the correct structure for Gemini
                        match transformed_schema {
                            serde_json::Value::Object(ref mut obj) => {
                                // Gemini expects the parameters to be an object with type: "object"
                                if !obj.contains_key("type") {
                                    obj.insert(
                                        "type".to_string(),
                                        serde_json::Value::String("object".to_string()),
                                    );
                                }
                                // If there are no properties, add an empty object
                                if !obj.contains_key("properties") {
                                    obj.insert(
                                        "properties".to_string(),
                                        serde_json::Value::Object(serde_json::Map::new()),
                                    );
                                }
                                // If there are no required fields, add an empty array
                                if !obj.contains_key("required") {
                                    obj.insert(
                                        "required".to_string(),
                                        serde_json::Value::Array(vec![]),
                                    );
                                }
                            }
                            _ => {
                                // If it's not an object, create a proper object schema
                                let mut schema_obj = serde_json::Map::new();
                                schema_obj.insert(
                                    "type".to_string(),
                                    serde_json::Value::String("object".to_string()),
                                );
                                schema_obj.insert(
                                    "properties".to_string(),
                                    serde_json::Value::Object(serde_json::Map::new()),
                                );
                                schema_obj.insert(
                                    "required".to_string(),
                                    serde_json::Value::Array(vec![]),
                                );
                                transformed_schema = serde_json::Value::Object(schema_obj);
                            }
                        }

                        GeminiFunctionDeclaration {
                            name: tool.name.clone(),
                            description: tool.description.clone(),
                            parameters: transformed_schema,
                        }
                    })
                    .collect(),
            }]
        });

        let request = GeminiRequest {
            contents,
            system_instruction,
            tools,
        };

        let stream = stream_events(&self.api_key, &self.model, request, llm_debug)
            .map_err(|e| WorkerError::llm_api("gemini", e.to_string()));
        Ok(Box::new(Box::pin(stream)))
    }

    pub async fn get_model_details(
        &self,
        model_name: &str,
    ) -> Result<crate::types::ModelInfo, WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_model_url("gemini", model_name);

        let response = client
            .get(&url)
            .header("x-goog-api-key", &self.api_key)
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
            })?;

        if !response.status().is_success() {
            return Err(WorkerError::from_api_error(
                format!(
                    "Gemini model details request failed with status: {}",
                    response.status()
                ),
                &crate::types::LlmProvider::Gemini,
            ));
        }

        let model_data: serde_json::Value = response.json().await.map_err(|e| {
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
        })?;

        let name = model_data
            .get("name")
            .and_then(|n| n.as_str())
            .unwrap_or(model_name);
        let display_name = model_data
            .get("displayName")
            .and_then(|d| d.as_str())
            .unwrap_or(name);
        let description = model_data
            .get("description")
            .and_then(|d| d.as_str())
            .unwrap_or("");
        let version = model_data
            .get("version")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let input_token_limit = model_data
            .get("inputTokenLimit")
            .and_then(|i| i.as_u64())
            .map(|i| i as u32);
        let _output_token_limit = model_data
            .get("outputTokenLimit")
            .and_then(|o| o.as_u64())
            .map(|o| o as u32);

        let empty_vec = Vec::new();
        let supported_methods = model_data
            .get("supportedGenerationMethods")
            .and_then(|s| s.as_array())
            .unwrap_or(&empty_vec);
        let supports_tools = supported_methods
            .iter()
            .any(|method| method.as_str() == Some("generateContent"));
        let supports_vision = false; // Will be determined dynamically
        let capabilities = vec!["text_generation".to_string()]; // Basic default

        Ok(crate::types::ModelInfo {
            id: model_name.to_string(),
            name: display_name.to_string(),
            provider: crate::types::LlmProvider::Gemini,
            supports_tools,
            supports_function_calling: supports_tools,
            supports_vision,
            supports_multimodal: supports_vision,
            context_length: input_token_limit,
            training_cutoff: version,
            capabilities,
            description: Some(if description.is_empty() {
                format!("Google Gemini model: {}", display_name)
            } else {
                description.to_string()
            }),
        })
    }

    pub async fn check_connection(&self) -> Result<(), WorkerError> {
        // Simple connection check: for now, just verify that a model name was configured
        // instead of calling the API
        if self.model.is_empty() {
            return Err(WorkerError::model_not_found("gemini", "No model specified"));
        }
        Ok(())
    }
}

#[async_trait::async_trait]
impl LlmClientTrait for GeminiClient {
    async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebugSettings>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        self.chat_stream(messages, tools, llm_debug).await
    }

    async fn check_connection(&self) -> Result<(), WorkerError> {
        self.check_connection().await
    }

    fn provider(&self) -> LlmProvider {
        LlmProvider::Gemini
    }

    fn get_model_name(&self) -> String {
        self.get_model_name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_schema_transformation() {
        // Test schema with various type formats including $schema
        let schema = serde_json::json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "properties": {
                "id": { "type": "uint" },
                "optional_number": { "type": ["integer", "null"] },
                "required_string": { "type": "string" },
                "existing_integer": { "type": "integer" },
                "nested": {
                    "type": "object",
                    "properties": {
                        "count": { "type": ["uint", "null"] }
                    }
                }
            },
            "required": ["id", "optional_number", "required_string"]
        });

        let transformed = transform_schema_for_gemini(schema);

        // Check that the schema has the correct structure
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());

        // Check that $schema key is removed
        assert!(transformed.get("$schema").is_none());

        // Check that 'uint' was transformed to 'integer'
        assert_eq!(transformed["properties"]["id"]["type"], "integer");
        assert_eq!(transformed["properties"]["id"]["format"], "int64");

        // Check that array types are converted to single types
        assert_eq!(
            transformed["properties"]["optional_number"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["optional_number"]["format"],
            "int64"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["format"],
            "int64"
        );

        // Check that existing integer types also get format
        assert_eq!(
            transformed["properties"]["existing_integer"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["existing_integer"]["format"],
            "int64"
        );

        // Check that required array is updated correctly (nullable properties should be removed)
        let required: Vec<&str> = transformed["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();
        assert!(required.contains(&"id"));
        assert!(required.contains(&"required_string"));
        assert!(!required.contains(&"optional_number")); // Should be removed because it's nullable
    }

    #[test]
    fn test_empty_schema_transformation() {
        // Test with an empty schema as would be processed in tool generation
        let schema = serde_json::json!({});
        let mut transformed = transform_schema_for_gemini(schema);

        // Apply the same logic as in tool generation
        match transformed {
            serde_json::Value::Object(ref mut obj) => {
                if !obj.contains_key("type") {
                    obj.insert(
                        "type".to_string(),
                        serde_json::Value::String("object".to_string()),
                    );
                }
                if !obj.contains_key("properties") {
                    obj.insert(
                        "properties".to_string(),
                        serde_json::Value::Object(serde_json::Map::new()),
                    );
                }
                if !obj.contains_key("required") {
                    obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                }
            }
            _ => {
                let mut schema_obj = serde_json::Map::new();
                schema_obj.insert(
                    "type".to_string(),
                    serde_json::Value::String("object".to_string()),
                );
                schema_obj.insert(
                    "properties".to_string(),
                    serde_json::Value::Object(serde_json::Map::new()),
                );
                schema_obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                transformed = serde_json::Value::Object(schema_obj);
            }
        }

        // Should be converted to a proper object schema
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());
    }
}
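// A serialization sketch of the request wire format implied by the serde attributes above:
// camelCase keys for renamed fields, optional fields omitted when `None`, and untagged
// parts serialized as bare `functionCall` / `functionResponse` objects. The values are
// illustrative, not taken from a real conversation.
#[cfg(test)]
mod wire_format_tests {
    use super::*;

    #[test]
    fn serializes_request_with_expected_keys() {
        let request = GeminiRequest {
            contents: vec![
                GeminiContent {
                    role: "model".to_string(),
                    parts: vec![GeminiPart::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: "get_weather".to_string(),
                            args: serde_json::json!({"city": "Tokyo"}),
                        },
                    }],
                },
                GeminiContent {
                    role: "user".to_string(),
                    parts: vec![GeminiPart::FunctionResponse {
                        function_response: GeminiFunctionResponse {
                            name: "get_weather".to_string(),
                            response: serde_json::json!({"temperature_c": 21}),
                        },
                    }],
                },
            ],
            system_instruction: None,
            tools: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        // Optional fields are skipped entirely when None
        assert!(json.get("systemInstruction").is_none());
        assert!(json.get("tools").is_none());
        // Untagged parts serialize as {"functionCall": ...} / {"functionResponse": ...}
        assert_eq!(
            json["contents"][0]["parts"][0]["functionCall"]["name"],
            "get_weather"
        );
        assert_eq!(
            json["contents"][1]["parts"][0]["functionResponse"]["response"]["temperature_c"],
            21
        );
    }
}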