llm-worker-rs/worker/src/llm/gemini.rs

977 lines
40 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::core::LlmClientTrait;
use crate::types::WorkerError;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
use crate::config::UrlConfig;
use futures_util::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing;
/// Extracts a tool name from tool-result text containing `Tool '<name>'`.
///
/// Returns the text between the first `Tool '` marker and the next `'`, or `None`
/// when the marker or the closing quote is missing. An empty name (`Tool ''`)
/// yields `Some("")`.
fn extract_tool_name_from_content(content: &str) -> Option<String> {
    let (_, after_marker) = content.split_once("Tool '")?;
    let (name, _) = after_marker.split_once('\'')?;
    Some(name.to_owned())
}
/// Transforms a JSON schema so it is accepted by the Gemini API as function parameters.
///
/// The transformation is applied recursively to every nested object and array:
/// - the `$schema` key is removed;
/// - `"type": "uint"` becomes `"type": "integer"`;
/// - union types such as `["integer", "null"]` collapse to their first non-`"null"` member;
/// - every node whose resulting type is `"integer"` gets `"format": "int64"`;
/// - when a node has both `properties` (an object) and `required` (an array), `required`
///   is rebuilt from the declared properties that were listed as required and whose type
///   is not a union containing `"null"` (names not declared in `properties` are dropped).
///
/// Scalars pass through unchanged.
fn transform_schema_for_gemini(schema: serde_json::Value) -> serde_json::Value {
    match schema {
        serde_json::Value::Object(mut obj) => {
            // Remove $schema key as it's not needed for Gemini
            obj.remove("$schema");
            normalize_schema_type(&mut obj);
            // Must run before recursing: once the children are transformed their nullable
            // unions are already collapsed and can no longer be detected.
            prune_nullable_required(&mut obj);
            for value in obj.values_mut() {
                // Move the child out (leaving `Null` momentarily) instead of deep-cloning it.
                *value = transform_schema_for_gemini(std::mem::take(value));
            }
            serde_json::Value::Object(obj)
        }
        serde_json::Value::Array(items) => serde_json::Value::Array(
            items.into_iter().map(transform_schema_for_gemini).collect(),
        ),
        other => other,
    }
}

/// Rewrites `obj["type"]` into a single Gemini-compatible type and adds
/// `"format": "int64"` when that type is `"integer"`.
///
/// `obj` is left untouched when it has no `type`, or when `type` is an array without
/// any non-`"null"` member.
fn normalize_schema_type(obj: &mut serde_json::Map<String, serde_json::Value>) {
    let resolved = match obj.get("type") {
        // Union types like ["integer", "null"]: keep the first non-null member.
        Some(serde_json::Value::Array(members)) => match members.iter().find(|&t| t != "null") {
            Some(member) => member.clone(),
            None => return,
        },
        Some(single) => single.clone(),
        None => return,
    };
    let resolved = if resolved == "uint" {
        serde_json::Value::String("integer".to_string())
    } else {
        resolved
    };
    let is_integer = resolved == "integer";
    obj.insert("type".to_string(), resolved);
    if is_integer {
        // Add format for integer types as required by Gemini
        obj.insert(
            "format".to_string(),
            serde_json::Value::String("int64".to_string()),
        );
    }
}

/// Rebuilds `obj["required"]` so it lists only declared properties that were required
/// and whose `type` is not a union containing `"null"`.
///
/// Does nothing unless `properties` is an object and `required` is an array.
fn prune_nullable_required(obj: &mut serde_json::Map<String, serde_json::Value>) {
    let kept: Vec<serde_json::Value> = match (obj.get("properties"), obj.get("required")) {
        (
            Some(serde_json::Value::Object(props)),
            Some(serde_json::Value::Array(required)),
        ) => props
            .iter()
            .filter(|(name, _)| required.iter().any(|r| r == *name))
            .filter(|(_, prop_schema)| !is_nullable_schema(prop_schema))
            .map(|(name, _)| serde_json::Value::String(name.clone()))
            .collect(),
        _ => return,
    };
    obj.insert("required".to_string(), serde_json::Value::Array(kept));
}

/// Returns `true` when `schema["type"]` is an array that contains `"null"`.
fn is_nullable_schema(schema: &serde_json::Value) -> bool {
    matches!(
        schema.get("type"),
        Some(serde_json::Value::Array(types)) if types.iter().any(|t| t == "null")
    )
}
// --- Request Structures ---

/// One entry of the request's `tools` array: a group of function declarations.
#[derive(Debug, Serialize, Clone)]
pub struct GeminiTool {
    #[serde(rename = "functionDeclarations")]
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// A callable function advertised to the model; `parameters` is a JSON schema.
#[derive(Debug, Serialize, Clone)]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// Body of a `generateContent`/`streamGenerateContent` request.
#[derive(Debug, Serialize, Clone)]
pub struct GeminiRequest {
    pub contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "systemInstruction")]
    pub system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,
}

/// One conversation turn (`role` plus its `parts`); used in requests and in response
/// candidates.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct GeminiContent {
    /// Defaults to an empty string when absent from a response object.
    #[serde(default)]
    pub role: String,
    /// Defaults to an empty list when absent from a response object.
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}

/// A single part of a content turn.
///
/// Untagged: serde tries `Text`, `FunctionCall` and `FunctionResponse` in that order and
/// picks the first whose field is present. A part matching none of them makes the
/// enclosing object fail to deserialize.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    Text {
        text: String,
    },
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}

/// A function invocation emitted by the model; `args` is the JSON argument object.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionCall {
    pub name: String,
    pub args: serde_json::Value,
}

/// The result of a function invocation, sent back to the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionResponse {
    pub name: String,
    pub response: serde_json::Value,
}

// --- Response Structures ---

/// One streamed response object; only `candidates` is read.
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    #[serde(default)]
    pub candidates: Vec<GeminiCandidate>,
}

/// A candidate completion inside a [`GeminiResponse`].
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    /// Defaults to empty content when absent, so a candidate that only carries a
    /// `finishReason` still deserializes. This is presumably how blocked candidates
    /// arrive — verify against the API.
    #[serde(default)]
    pub content: GeminiContent,
    /// Missing `finishReason` deserializes as `None`.
    pub finish_reason: Option<String>,
}
/// Returns the streaming endpoint for `model`, i.e.
/// `<gemini base url>/v1beta/models/<model>:streamGenerateContent`.
fn build_url(model: &str) -> String {
    const STREAM_ACTION: &str = "streamGenerateContent";
    let base = UrlConfig::get_base_url("gemini");
    format!("{}/v1beta/models/{}:{}", base, model, STREAM_ACTION)
}
/// Locates the first balanced JSON object in `buffer`.
///
/// Scanning starts at the first `{` byte. Braces inside string literals, including
/// ones following an escaped quote, are ignored. Returns `Some((start, end))` such that
/// `buffer[start..end]` spans the object including both braces. Returns `None` when no
/// `{` exists or the object is not closed yet.
fn find_first_json_object_bounds(buffer: &[u8]) -> Option<(usize, usize)> {
    let start = buffer.iter().position(|&b| b == b'{')?;
    let mut depth: usize = 0;
    let mut inside_string = false;
    let mut after_backslash = false;
    for (offset, &b) in buffer[start..].iter().enumerate() {
        if inside_string {
            match (after_backslash, b) {
                (true, _) => after_backslash = false,
                (false, b'\\') => after_backslash = true,
                (false, b'"') => inside_string = false,
                _ => {}
            }
            continue;
        }
        match b {
            b'"' => inside_string = true,
            b'{' => depth += 1,
            b'}' => {
                // Never underflows: the first byte scanned is `{` and we return at depth 0.
                depth -= 1;
                if depth == 0 {
                    return Some((start, start + offset + 1));
                }
            }
            _ => {}
        }
    }
    None
}
/// Streams a Gemini `streamGenerateContent` call as [`StreamEvent`]s.
///
/// Event order:
/// 1. If `llm_debug` is set, whatever `debug_request` returns for the serialized request.
/// 2. For each complete JSON object read from the response body:
///    - the `debug_response` event, if any;
///    - a `StreamEvent::Chunk` for each text part of the first candidate;
///    - a `StreamEvent::ToolCall` for each function-call part of the first candidate.
/// 3. A final `StreamEvent::Completion` holding all streamed text concatenated. Tool
///    calls are not part of this message.
///
/// The body is buffered and split with [`find_first_json_object_bounds`], so bytes
/// between objects are skipped. These are presumably the `[`, `,` and `]` of a streamed
/// JSON array.
///
/// # Errors
/// The stream yields an error when:
/// - the HTTP request cannot be sent;
/// - the response status is not a success (the body is included in the message);
/// - reading a body chunk fails.
///
/// Objects that fail to deserialize as [`GeminiResponse`] are logged and skipped.
pub(crate) fn stream_events<'a>(
    api_key: &'a str,
    model: &'a str,
    request: GeminiRequest,
    llm_debug: Option<crate::types::LlmDebug>,
) -> impl Stream<Item = anyhow::Result<StreamEvent>> + 'a {
    // Owned copies so the generator below does not borrow the caller's strings.
    let api_key = api_key.to_string();
    let model = model.to_string();
    async_stream::try_stream! {
        let body = serde_json::to_string_pretty(&request).unwrap_or_else(|e| e.to_string());
        tracing::debug!("Gemini Request Body: {}", body);
        if let Some(debug_settings) = &llm_debug {
            if let Some(debug_event) = debug_settings.debug_request(&model, "Gemini", &serde_json::to_value(&request).unwrap_or_default()) {
                yield debug_event;
            }
        }
        let client = Client::new();
        let url = build_url(&model);
        let response = client
            .post(&url)
            .header("x-goog-api-key", &api_key)
            .json(&request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("Gemini API request failed: {}", e))?;
        let status = response.status();
        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_else(|_| "Could not read error body".to_string());
            let error_msg = format!("Gemini API request failed with status: {} - {}", status, error_body);
            tracing::error!("{}", error_msg);
            Err(anyhow::anyhow!(error_msg))?;
        } else {
            let mut byte_stream = response.bytes_stream();
            // Bytes received but not yet consumed as part of a complete JSON object.
            let mut buffer = Vec::new();
            // Concatenation of every text part, used for the final Completion message.
            let mut full_content = String::new();
            while let Some(chunk_result) = byte_stream.next().await {
                let chunk = chunk_result?;
                buffer.extend_from_slice(&chunk);
                // A network chunk may complete zero, one or several objects.
                while let Some((start, end)) = find_first_json_object_bounds(&buffer) {
                    let json_slice = &buffer[start..end];
                    if let Some(debug_settings) = &llm_debug {
                        if let Ok(response_value) = serde_json::from_slice::<serde_json::Value>(json_slice) {
                            if let Some(debug_event) = debug_settings.debug_response(&model, "Gemini", &response_value) {
                                yield debug_event;
                            }
                        }
                    }
                    match serde_json::from_slice::<GeminiResponse>(json_slice) {
                        Ok(response) => {
                            let response_text = String::from_utf8_lossy(json_slice);
                            tracing::debug!(
                                response = %response_text,
                                candidates_count = response.candidates.len(),
                                "Successfully parsed Gemini response"
                            );
                            if response.candidates.is_empty() {
                                tracing::warn!(
                                    response = %response_text,
                                    "Received empty candidates in Gemini response"
                                );
                            } else if let Some(candidate) = response.candidates.first() {
                                // Only the first candidate is streamed; finish reasons are logged only.
                                if let Some(ref finish_reason) = candidate.finish_reason {
                                    tracing::debug!(
                                        finish_reason = %finish_reason,
                                        "Received finish reason in Gemini response"
                                    );
                                    match finish_reason.as_str() {
                                        "STOP" => {
                                            // Normal completion; any parts are still processed below.
                                            tracing::debug!("Gemini response completed with STOP");
                                        }
                                        "MAX_TOKENS" => {
                                            tracing::warn!("Gemini response stopped due to MAX_TOKENS");
                                        }
                                        "SAFETY" => {
                                            tracing::warn!("Gemini response stopped due to SAFETY concerns");
                                        }
                                        "RECITATION" => {
                                            tracing::warn!("Gemini response stopped due to RECITATION");
                                        }
                                        other => {
                                            tracing::warn!("Gemini response stopped with unknown reason: {}", other);
                                        }
                                    }
                                }
                                if candidate.content.parts.is_empty() {
                                    tracing::warn!(
                                        response = %response_text,
                                        role = %candidate.content.role,
                                        finish_reason = ?candidate.finish_reason,
                                        "Received empty parts in Gemini response"
                                    );
                                } else {
                                    for part in &candidate.content.parts {
                                        tracing::debug!("Processing Gemini part (type unknown)");
                                        match part {
                                            GeminiPart::Text { text } => {
                                                tracing::debug!("Found Text part with content length: {}", text.len());
                                                full_content.push_str(text);
                                                yield StreamEvent::Chunk(text.clone());
                                            }
                                            GeminiPart::FunctionCall { function_call } => {
                                                tracing::debug!("Found FunctionCall part: name={}, args={:?}", function_call.name, function_call.args);
                                                let tool_call = ToolCall {
                                                    name: function_call.name.clone(),
                                                    arguments: serde_json::to_string(&function_call.args)
                                                        .unwrap_or_else(|_| "{}".to_string()),
                                                };
                                                yield StreamEvent::ToolCall(tool_call);
                                            }
                                            GeminiPart::FunctionResponse { .. } => {
                                                // Function responses belong to the input history,
                                                // not to model output.
                                                tracing::warn!("Unexpected FunctionResponse in model output");
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            // Best-effort: an object we cannot model is logged and skipped.
                            let response_text = String::from_utf8_lossy(json_slice);
                            tracing::warn!(
                                error = %e,
                                response = %response_text,
                                "Failed to deserialize GeminiResponse from slice"
                            );
                        }
                    }
                    // Drop the consumed object and any separator bytes before it.
                    buffer.drain(..end);
                }
            }
            // Anything left besides whitespace and array punctuation was never parsed
            // (e.g. a truncated final object); report it instead of dropping it silently.
            {
                let leftover = String::from_utf8_lossy(&buffer);
                let leftover = leftover
                    .trim_matches(|c: char| c.is_whitespace() || matches!(c, '[' | ']' | ','));
                if !leftover.is_empty() {
                    tracing::warn!(
                        leftover = %leftover,
                        "Gemini stream ended with unparsed data"
                    );
                }
            }
            let final_message = Message::new(
                Role::Model,
                full_content,
            );
            yield StreamEvent::Completion(final_message);
        }
    }
}
/// Client for the Gemini API bound to a single model.
pub struct GeminiClient {
    /// Model identifier used in request URLs.
    model: String,
    /// API key sent in the `x-goog-api-key` header.
    api_key: String,
}
impl GeminiClient {
pub fn new(api_key: &str, model: &str) -> Self {
Self {
api_key: api_key.to_string(),
model: model.to_string(),
}
}
pub fn get_model_name(&self) -> String {
self.model.clone()
}
/// 静的メソッドAPI キーを受け取ってモデル一覧を取得
pub async fn list_models_static(
api_key: &str,
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_models_url("gemini");
let response = client
.get(url)
.header("x-goog-api-key", api_key)
.send()
.await
.map_err(|e| {
tracing::error!("Gemini API request failed: {}", e);
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
if !response.status().is_success() {
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!(
"Gemini list_models_static failed - Status: {}, Body: {}",
status,
error_body
);
return Err(WorkerError::from_api_error(
format!("Failed to list Gemini models: {} - {}", status, error_body),
&crate::types::LlmProvider::Gemini,
));
}
let models_response: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
let mut models = Vec::new();
if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
for model in models_array {
if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
// "models/" プレフィックスを除去
let model_id = name.strip_prefix("models/").unwrap_or(name);
// generateContentメソッドをサポートするモデルのみを含める
if let Some(supported_methods) = model
.get("supportedGenerationMethods")
.and_then(|m| m.as_array())
{
let supports_generate_content = supported_methods
.iter()
.any(|method| method.as_str() == Some("generateContent"));
if supports_generate_content {
models.push(crate::types::ModelInfo {
id: model_id.to_string(),
name: model
.get("displayName")
.and_then(|d| d.as_str())
.unwrap_or(model_id)
.to_string(),
provider: crate::types::LlmProvider::Gemini,
supports_tools: true,
supports_function_calling: true,
supports_vision: false,
supports_multimodal: false,
context_length: None,
training_cutoff: None,
capabilities: vec!["text_generation".to_string()],
description: model
.get("description")
.and_then(|d| d.as_str())
.map(|s| s.to_string())
.or_else(|| Some(format!("Google Gemini model: {}", model_id))),
});
}
}
}
}
}
tracing::info!(
"Gemini list_models_static found {} models with metadata",
models.len()
);
Ok(models)
}
pub async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
// Separate system messages from regular messages
let (system_messages, regular_messages): (Vec<_>, Vec<_>) = messages
.into_iter()
.partition(|msg| matches!(msg.role, Role::System));
// Create system instruction from system messages
let system_instruction = if !system_messages.is_empty() {
let combined_system_content = system_messages
.into_iter()
.map(|msg| msg.content)
.collect::<Vec<_>>()
.join("\n\n");
Some(GeminiContent {
role: "user".to_string(), // System instruction uses "user" role
parts: vec![GeminiPart::Text {
text: combined_system_content,
}],
})
} else {
None
};
// Process regular messages with proper tool context handling
let contents = regular_messages
.into_iter()
.map(|msg| {
let (role, parts) = match msg.role {
Role::User => (
"user".to_string(),
vec![GeminiPart::Text { text: msg.content }],
),
Role::Model => {
if let Some(tool_calls) = &msg.tool_calls {
// Model message with tool calls - convert to FunctionCall parts
tracing::debug!(
"Converting model message with {} tool calls to FunctionCall parts",
tool_calls.len()
);
let mut parts = Vec::new();
// Add text content if present
if !msg.content.is_empty() {
parts.push(GeminiPart::Text {
text: msg.content.clone(),
});
}
// Add function calls
for tool_call in tool_calls {
tracing::debug!(
"Adding FunctionCall part for tool: {}",
tool_call.name
);
let args = serde_json::from_str(&tool_call.arguments)
.unwrap_or(serde_json::json!({}));
parts.push(GeminiPart::FunctionCall {
function_call: GeminiFunctionCall {
name: tool_call.name.clone(),
args,
},
});
}
("model".to_string(), parts)
} else {
// Regular model message
tracing::debug!("Converting regular model message (no tool calls)");
(
"model".to_string(),
vec![GeminiPart::Text { text: msg.content }],
)
}
}
Role::Tool => {
// Tool responses should be sent as FunctionResponse
if let Some(tool_name) = extract_tool_name_from_content(&msg.content) {
// Extract result from the content
let result_value = if msg.content.contains("Result: ") {
if let Some(result_start) = msg.content.find("Result: ") {
let result_str = &msg.content[result_start + 8..];
// Try to parse as JSON, fallback to string
serde_json::from_str(result_str)
.unwrap_or_else(|_| serde_json::json!(result_str))
} else {
serde_json::json!(msg.content)
}
} else {
serde_json::json!(msg.content)
};
(
"user".to_string(),
vec![GeminiPart::FunctionResponse {
function_response: GeminiFunctionResponse {
name: tool_name,
response: result_value,
},
}],
)
} else {
// Fallback to text response if tool name can't be extracted
(
"user".to_string(),
vec![GeminiPart::Text {
text: format!("Tool Response:\n{}", msg.content),
}],
)
}
}
Role::System => unreachable!(), // Should not reach here after partition
};
GeminiContent { role, parts }
})
.collect();
let tools = tools.map(|tools| {
vec![GeminiTool {
function_declarations: tools
.iter()
.map(|tool| {
let mut transformed_schema =
transform_schema_for_gemini(tool.parameters_schema.clone());
// Ensure the schema has the correct structure for Gemini
match transformed_schema {
serde_json::Value::Object(ref mut obj) => {
// Gemini expects the parameters to be an object with type: "object"
if !obj.contains_key("type") {
obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
}
// If there are no properties, add an empty object
if !obj.contains_key("properties") {
obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
// If there are no required fields, add an empty array
if !obj.contains_key("required") {
obj.insert(
"required".to_string(),
serde_json::Value::Array(vec![]),
);
}
}
_ => {
// If it's not an object, create a proper object schema
let mut schema_obj = serde_json::Map::new();
schema_obj.insert(
"type".to_string(),
serde_json::Value::String("object".to_string()),
);
schema_obj.insert(
"properties".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
schema_obj.insert(
"required".to_string(),
serde_json::Value::Array(vec![]),
);
transformed_schema = serde_json::Value::Object(schema_obj);
}
}
GeminiFunctionDeclaration {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: transformed_schema,
}
})
.collect(),
}]
});
let request = GeminiRequest {
contents,
system_instruction,
tools,
};
let stream = stream_events(&self.api_key, &self.model, request, llm_debug)
.map_err(|e| WorkerError::llm_api("gemini", e.to_string()));
Ok(Box::new(Box::pin(stream)))
}
pub async fn get_model_details(
&self,
model_name: &str,
) -> Result<crate::types::ModelInfo, WorkerError> {
let client = Client::new();
let url = UrlConfig::get_model_url("gemini", model_name);
let response = client
.get(&url)
.header("x-goog-api-key", &self.api_key)
.send()
.await
.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
if !response.status().is_success() {
return Err(WorkerError::from_api_error(
format!(
"Gemini model details request failed with status: {}",
response.status()
),
&crate::types::LlmProvider::Gemini,
));
}
let model_data: serde_json::Value = response.json().await.map_err(|e| {
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
})?;
let name = model_data
.get("name")
.and_then(|n| n.as_str())
.unwrap_or(model_name);
let display_name = model_data
.get("displayName")
.and_then(|d| d.as_str())
.unwrap_or(name);
let description = model_data
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("");
let version = model_data
.get("version")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let input_token_limit = model_data
.get("inputTokenLimit")
.and_then(|i| i.as_u64())
.map(|i| i as u32);
let _output_token_limit = model_data
.get("outputTokenLimit")
.and_then(|o| o.as_u64())
.map(|o| o as u32);
let empty_vec = Vec::new();
let supported_methods = model_data
.get("supportedGenerationMethods")
.and_then(|s| s.as_array())
.unwrap_or(&empty_vec);
let supports_tools = supported_methods
.iter()
.any(|method| method.as_str() == Some("generateContent"));
let supports_vision = false; // Will be determined dynamically
let capabilities = vec!["text_generation".to_string()]; // Basic default
Ok(crate::types::ModelInfo {
id: model_name.to_string(),
name: display_name.to_string(),
provider: crate::types::LlmProvider::Gemini,
supports_tools,
supports_function_calling: supports_tools,
supports_vision,
supports_multimodal: supports_vision,
context_length: input_token_limit,
training_cutoff: version,
capabilities,
description: Some(if description.is_empty() {
format!("Google Gemini model: {}", display_name)
} else {
description.to_string()
}),
})
}
pub async fn check_connection(&self) -> Result<(), WorkerError> {
// Simple connection check - try to call the API
// For now, just return OK if model is not empty
if self.model.is_empty() {
return Err(WorkerError::model_not_found("gemini", "No model specified"));
}
Ok(())
}
}
#[async_trait::async_trait]
impl LlmClientTrait for GeminiClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Gemini
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_schema_transformation() {
        // Test schema with various type formats including $schema
        let schema = serde_json::json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "properties": {
                "id": {
                    "type": "uint"
                },
                "optional_number": {
                    "type": ["integer", "null"]
                },
                "required_string": {
                    "type": "string"
                },
                "existing_integer": {
                    "type": "integer"
                },
                "nested": {
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": ["uint", "null"]
                        }
                    }
                }
            },
            "required": ["id", "optional_number", "required_string"]
        });
        let transformed = transform_schema_for_gemini(schema);
        // Check that the schema has the correct structure
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());
        // Check that $schema key is removed
        assert!(transformed.get("$schema").is_none());
        // Check that 'uint' was transformed to 'integer'
        assert_eq!(transformed["properties"]["id"]["type"], "integer");
        assert_eq!(transformed["properties"]["id"]["format"], "int64");
        // Check that array types are converted to single types
        assert_eq!(
            transformed["properties"]["optional_number"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["optional_number"]["format"],
            "int64"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["format"],
            "int64"
        );
        // Check that existing integer types also get format
        assert_eq!(
            transformed["properties"]["existing_integer"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["existing_integer"]["format"],
            "int64"
        );
        // Check that required array is updated correctly (nullable properties should be removed)
        let required: Vec<&str> = transformed["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();
        assert!(required.contains(&"id"));
        assert!(required.contains(&"required_string"));
        assert!(!required.contains(&"optional_number")); // Should be removed because it's nullable
    }

    #[test]
    fn test_empty_schema_transformation() {
        // Test with an empty schema as would be processed in tool generation
        let schema = serde_json::json!({});
        let mut transformed = transform_schema_for_gemini(schema);
        // Apply the same logic as in tool generation
        match transformed {
            serde_json::Value::Object(ref mut obj) => {
                if !obj.contains_key("type") {
                    obj.insert(
                        "type".to_string(),
                        serde_json::Value::String("object".to_string()),
                    );
                }
                if !obj.contains_key("properties") {
                    obj.insert(
                        "properties".to_string(),
                        serde_json::Value::Object(serde_json::Map::new()),
                    );
                }
                if !obj.contains_key("required") {
                    obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                }
            }
            _ => {
                let mut schema_obj = serde_json::Map::new();
                schema_obj.insert(
                    "type".to_string(),
                    serde_json::Value::String("object".to_string()),
                );
                schema_obj.insert(
                    "properties".to_string(),
                    serde_json::Value::Object(serde_json::Map::new()),
                );
                schema_obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                transformed = serde_json::Value::Object(schema_obj);
            }
        }
        // Should be converted to a proper object schema
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());
    }

    #[test]
    fn test_find_first_json_object_bounds() {
        // Leading array punctuation is skipped; the end index is exclusive.
        assert_eq!(find_first_json_object_bounds(br#"[{"a":1},"#), Some((1, 8)));
        // Braces and escaped quotes inside strings do not affect nesting.
        assert_eq!(find_first_json_object_bounds(br#"{"a":"}"}"#), Some((0, 9)));
        assert_eq!(find_first_json_object_bounds(br#"{"a":"\"}"}"#), Some((0, 11)));
        // Nested objects are returned whole.
        assert_eq!(find_first_json_object_bounds(br#"{"a":{"b":2}}x"#), Some((0, 13)));
        // Incomplete or absent objects yield None.
        assert_eq!(find_first_json_object_bounds(br#"{"a":{"#), None);
        assert_eq!(find_first_json_object_bounds(b"[]"), None);
    }

    #[test]
    fn test_extract_tool_name_from_content() {
        assert_eq!(
            extract_tool_name_from_content("Tool 'search' executed successfully"),
            Some("search".to_string())
        );
        assert_eq!(extract_tool_name_from_content("Tool 'unterminated"), None);
        assert_eq!(extract_tool_name_from_content("no tool here"), None);
    }
}