977 lines
40 KiB
Rust
977 lines
40 KiB
Rust
use crate::core::LlmClientTrait;
|
||
use crate::types::WorkerError;
|
||
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};
|
||
use crate::config::UrlConfig;
|
||
use futures_util::{Stream, StreamExt, TryStreamExt};
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use tracing;
|
||
|
||
/// Extracts the tool name embedded in a Tool-role message's content.
///
/// Finds the first occurrence of the marker `Tool '` and returns the text between it and the
/// next single quote, so `"Tool 'search' executed successfully"` yields `Some("search")`.
/// Returns `None` when the marker is absent or no closing quote follows it.
fn extract_tool_name_from_content(content: &str) -> Option<String> {
    const OPENING_MARKER: &str = "Tool '";
    let (_, after_marker) = content.split_once(OPENING_MARKER)?;
    let (tool_name, _) = after_marker.split_once('\'')?;
    Some(tool_name.to_owned())
}
|
||
|
||
/// Parse tool call information from message content
|
||
|
||
/// Transforms a JSON schema to be compatible with Gemini API
|
||
/// Converts 'uint' types to 'integer' types and handles nullable types
|
||
/// Also ensures the schema is in the correct format for Gemini function parameters
|
||
fn transform_schema_for_gemini(schema: serde_json::Value) -> serde_json::Value {
|
||
match schema {
|
||
serde_json::Value::Object(mut obj) => {
|
||
// Remove $schema key as it's not needed for Gemini
|
||
obj.remove("$schema");
|
||
|
||
// Handle type field
|
||
if let Some(type_val) = obj.get("type") {
|
||
match type_val {
|
||
// Convert 'uint' to 'integer'
|
||
serde_json::Value::String(s) if s == "uint" => {
|
||
obj.insert(
|
||
"type".to_string(),
|
||
serde_json::Value::String("integer".to_string()),
|
||
);
|
||
// Add format for integer types as required by Gemini
|
||
obj.insert(
|
||
"format".to_string(),
|
||
serde_json::Value::String("int64".to_string()),
|
||
);
|
||
}
|
||
// Handle array types like ["integer", "null"]
|
||
serde_json::Value::Array(arr) => {
|
||
if let Some(non_null_type) = arr.iter().find(|&t| t != "null") {
|
||
// Use the non-null type
|
||
let mut new_type = non_null_type.clone();
|
||
// Convert 'uint' to 'integer' if needed
|
||
if let serde_json::Value::String(s) = &new_type {
|
||
if s == "uint" {
|
||
new_type = serde_json::Value::String("integer".to_string());
|
||
}
|
||
}
|
||
obj.insert("type".to_string(), new_type.clone());
|
||
|
||
// Add format for integer types as required by Gemini
|
||
if let serde_json::Value::String(type_str) = &new_type {
|
||
if type_str == "integer" {
|
||
obj.insert(
|
||
"format".to_string(),
|
||
serde_json::Value::String("int64".to_string()),
|
||
);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// Handle existing integer types
|
||
serde_json::Value::String(s) if s == "integer" => {
|
||
// Add format for integer types as required by Gemini
|
||
obj.insert(
|
||
"format".to_string(),
|
||
serde_json::Value::String("int64".to_string()),
|
||
);
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
// Handle properties and required fields
|
||
if let (Some(properties), Some(required)) = (obj.get("properties"), obj.get("required"))
|
||
{
|
||
if let (serde_json::Value::Object(props), serde_json::Value::Array(req_arr)) =
|
||
(properties, required)
|
||
{
|
||
let mut new_required = Vec::new();
|
||
|
||
for (prop_name, _) in props {
|
||
// Only include in required if it's not nullable
|
||
if req_arr.iter().any(|r| r == prop_name) {
|
||
// Check if this property has a nullable type
|
||
if let Some(prop_schema) = props.get(prop_name) {
|
||
if let Some(prop_type) = prop_schema.get("type") {
|
||
// If type is an array containing "null", it's nullable
|
||
let is_nullable = match prop_type {
|
||
serde_json::Value::Array(arr) => {
|
||
arr.iter().any(|t| t == "null")
|
||
}
|
||
_ => false,
|
||
};
|
||
|
||
// Only add to required if not nullable
|
||
if !is_nullable {
|
||
new_required
|
||
.push(serde_json::Value::String(prop_name.clone()));
|
||
}
|
||
} else {
|
||
// No type info, assume required
|
||
new_required.push(serde_json::Value::String(prop_name.clone()));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
obj.insert(
|
||
"required".to_string(),
|
||
serde_json::Value::Array(new_required),
|
||
);
|
||
}
|
||
}
|
||
|
||
// Recursively transform nested objects
|
||
for (_, value) in obj.iter_mut() {
|
||
*value = transform_schema_for_gemini(value.clone());
|
||
}
|
||
|
||
serde_json::Value::Object(obj)
|
||
}
|
||
serde_json::Value::Array(arr) => {
|
||
serde_json::Value::Array(arr.into_iter().map(transform_schema_for_gemini).collect())
|
||
}
|
||
other => other,
|
||
}
|
||
}
|
||
|
||
// --- Request Structures ---
|
||
#[derive(Debug, Serialize, Clone)]
|
||
pub struct GeminiTool {
|
||
#[serde(rename = "functionDeclarations")]
|
||
pub function_declarations: Vec<GeminiFunctionDeclaration>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Clone)]
|
||
pub struct GeminiFunctionDeclaration {
|
||
pub name: String,
|
||
pub description: String,
|
||
pub parameters: serde_json::Value,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Clone)]
|
||
pub struct GeminiRequest {
|
||
pub contents: Vec<GeminiContent>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
#[serde(rename = "systemInstruction")]
|
||
pub system_instruction: Option<GeminiContent>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub tools: Option<Vec<GeminiTool>>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||
pub struct GeminiContent {
|
||
pub role: String,
|
||
#[serde(default)]
|
||
pub parts: Vec<GeminiPart>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||
#[serde(untagged)]
|
||
pub enum GeminiPart {
|
||
Text {
|
||
text: String,
|
||
},
|
||
FunctionCall {
|
||
#[serde(rename = "functionCall")]
|
||
function_call: GeminiFunctionCall,
|
||
},
|
||
FunctionResponse {
|
||
#[serde(rename = "functionResponse")]
|
||
function_response: GeminiFunctionResponse,
|
||
},
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||
pub struct GeminiFunctionCall {
|
||
pub name: String,
|
||
pub args: serde_json::Value,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||
pub struct GeminiFunctionResponse {
|
||
pub name: String,
|
||
pub response: serde_json::Value,
|
||
}
|
||
|
||
// --- Response Structures ---
|
||
#[derive(Debug, Deserialize, Clone)]
|
||
#[serde(rename_all = "camelCase")]
|
||
pub struct GeminiResponse {
|
||
#[serde(default)]
|
||
pub candidates: Vec<GeminiCandidate>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize, Clone)]
|
||
#[serde(rename_all = "camelCase")]
|
||
pub struct GeminiCandidate {
|
||
pub content: GeminiContent,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub finish_reason: Option<String>,
|
||
}
|
||
|
||
fn build_url(model: &str) -> String {
|
||
let base_url = UrlConfig::get_base_url("gemini");
|
||
let action = "streamGenerateContent";
|
||
format!("{}/v1beta/models/{}:{}", base_url, model, action)
|
||
}
|
||
|
||
/// Locates the first complete, brace-balanced JSON object in `buffer`.
///
/// Scanning begins at the first `{` byte. Braces inside string literals are ignored, and a
/// backslash inside a string escapes the following byte (so `\"` does not end the string).
/// Returns `Some((start, end))` such that `buffer[start..end]` is the object including both
/// braces, or `None` when there is no `{` or the object is not closed yet.
fn find_first_json_object_bounds(buffer: &[u8]) -> Option<(usize, usize)> {
    let start = buffer.iter().position(|&b| b == b'{')?;
    let mut depth = 0;
    let mut inside_string = false;
    let mut after_backslash = false;

    for (offset, &b) in buffer[start..].iter().enumerate() {
        if inside_string {
            match (after_backslash, b) {
                (true, _) => after_backslash = false,
                (false, b'\\') => after_backslash = true,
                (false, b'"') => inside_string = false,
                _ => {}
            }
            continue;
        }

        match b {
            b'"' => inside_string = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    return Some((start, start + offset + 1));
                }
            }
            _ => {}
        }
    }

    // Opening brace found but never balanced: the object is still incomplete
    None
}
|
||
|
||
/// Completes a chat request with streaming, yielding StreamEvent objects.
|
||
pub(crate) fn stream_events<'a>(
|
||
api_key: &'a str,
|
||
model: &'a str,
|
||
request: GeminiRequest,
|
||
llm_debug: Option<crate::types::LlmDebug>,
|
||
) -> impl Stream<Item = anyhow::Result<StreamEvent>> + 'a {
|
||
let api_key = api_key.to_string();
|
||
let model = model.to_string();
|
||
|
||
async_stream::try_stream! {
|
||
let body = serde_json::to_string_pretty(&request).unwrap_or_else(|e| e.to_string());
|
||
tracing::debug!("Gemini Request Body: {}", body);
|
||
|
||
if let Some(debug_settings) = &llm_debug {
|
||
if let Some(debug_event) = debug_settings.debug_request(&model, "Gemini", &serde_json::to_value(&request).unwrap_or_default()) {
|
||
yield debug_event;
|
||
}
|
||
}
|
||
|
||
let client = Client::new();
|
||
let url = build_url(&model);
|
||
|
||
let response = client
|
||
.post(&url)
|
||
.header("x-goog-api-key", &api_key)
|
||
.json(&request)
|
||
.send()
|
||
.await
|
||
.map_err(|e| anyhow::anyhow!("Gemini API request failed: {}", e))?;
|
||
|
||
let status = response.status();
|
||
if !status.is_success() {
|
||
let error_body = response.text().await.unwrap_or_else(|_| "Could not read error body".to_string());
|
||
let error_msg = format!("Gemini API request failed with status: {} - {}", status, error_body);
|
||
tracing::error!("{}", error_msg);
|
||
Err(anyhow::anyhow!(error_msg))?;
|
||
} else {
|
||
let mut byte_stream = response.bytes_stream();
|
||
let mut buffer = Vec::new();
|
||
let mut full_content = String::new();
|
||
|
||
while let Some(chunk_result) = byte_stream.next().await {
|
||
let chunk = chunk_result?;
|
||
buffer.extend_from_slice(&chunk);
|
||
|
||
while let Some((start, end)) = find_first_json_object_bounds(&buffer) {
|
||
let json_slice = &buffer[start..end];
|
||
|
||
if let Some(debug_settings) = &llm_debug {
|
||
if let Ok(response_value) = serde_json::from_slice::<serde_json::Value>(json_slice) {
|
||
if let Some(debug_event) = debug_settings.debug_response(&model, "Gemini", &response_value) {
|
||
yield debug_event;
|
||
}
|
||
}
|
||
}
|
||
|
||
match serde_json::from_slice::<GeminiResponse>(json_slice) {
|
||
Ok(response) => {
|
||
let response_text = String::from_utf8_lossy(json_slice);
|
||
tracing::debug!(
|
||
response = %response_text,
|
||
candidates_count = response.candidates.len(),
|
||
"Successfully parsed Gemini response"
|
||
);
|
||
if response.candidates.is_empty() {
|
||
tracing::warn!(
|
||
response = %response_text,
|
||
"Received empty candidates in Gemini response"
|
||
);
|
||
} else if let Some(candidate) = response.candidates.get(0) {
|
||
// Log finish reason for debugging
|
||
if let Some(ref finish_reason) = candidate.finish_reason {
|
||
tracing::debug!(
|
||
finish_reason = %finish_reason,
|
||
"Received finish reason in Gemini response"
|
||
);
|
||
|
||
// Handle specific finish reasons
|
||
match finish_reason.as_str() {
|
||
"STOP" => {
|
||
tracing::debug!("Gemini response completed with STOP");
|
||
// Continue processing parts if any, this is normal completion
|
||
}
|
||
"MAX_TOKENS" => {
|
||
tracing::warn!("Gemini response stopped due to MAX_TOKENS");
|
||
}
|
||
"SAFETY" => {
|
||
tracing::warn!("Gemini response stopped due to SAFETY concerns");
|
||
}
|
||
"RECITATION" => {
|
||
tracing::warn!("Gemini response stopped due to RECITATION");
|
||
}
|
||
other => {
|
||
tracing::warn!("Gemini response stopped with unknown reason: {}", other);
|
||
}
|
||
}
|
||
}
|
||
|
||
if candidate.content.parts.is_empty() {
|
||
tracing::warn!(
|
||
response = %response_text,
|
||
role = %candidate.content.role,
|
||
finish_reason = ?candidate.finish_reason,
|
||
"Received empty parts in Gemini response"
|
||
);
|
||
} else {
|
||
for part in &candidate.content.parts {
|
||
tracing::debug!("Processing Gemini part (type unknown)");
|
||
match part {
|
||
GeminiPart::Text { text } => {
|
||
tracing::debug!("Found Text part with content length: {}", text.len());
|
||
full_content.push_str(text);
|
||
yield StreamEvent::Chunk(text.clone());
|
||
}
|
||
GeminiPart::FunctionCall { function_call } => {
|
||
tracing::debug!("Found FunctionCall part: name={}, args={:?}", function_call.name, function_call.args);
|
||
let tool_call = ToolCall {
|
||
name: function_call.name.clone(),
|
||
arguments: serde_json::to_string(&function_call.args)
|
||
.unwrap_or_else(|_| "{}".to_string()),
|
||
};
|
||
yield StreamEvent::ToolCall(tool_call.clone());
|
||
}
|
||
GeminiPart::FunctionResponse { .. } => {
|
||
// Function responses in model output are not expected
|
||
// as they're part of the input conversation history
|
||
tracing::warn!("Unexpected FunctionResponse in model output");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
Err(e) => {
|
||
let response_text = String::from_utf8_lossy(json_slice);
|
||
tracing::warn!(
|
||
error = %e,
|
||
response = %response_text,
|
||
"Failed to deserialize GeminiResponse from slice"
|
||
);
|
||
}
|
||
}
|
||
buffer.drain(..end);
|
||
}
|
||
}
|
||
|
||
let final_message = Message::new(
|
||
Role::Model,
|
||
full_content,
|
||
);
|
||
yield StreamEvent::Completion(final_message);
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Gemini API client bound to a single API key and model.
pub struct GeminiClient {
    /// Sent as the `x-goog-api-key` header on API requests.
    api_key: String,
    /// Model id used to build the generation endpoint URL.
    model: String,
}
|
||
|
||
impl GeminiClient {
|
||
pub fn new(api_key: &str, model: &str) -> Self {
|
||
Self {
|
||
api_key: api_key.to_string(),
|
||
model: model.to_string(),
|
||
}
|
||
}
|
||
|
||
pub fn get_model_name(&self) -> String {
|
||
self.model.clone()
|
||
}
|
||
|
||
/// 静的メソッド:API キーを受け取ってモデル一覧を取得
|
||
pub async fn list_models_static(
|
||
api_key: &str,
|
||
) -> Result<Vec<crate::types::ModelInfo>, WorkerError> {
|
||
let client = Client::new();
|
||
let url = UrlConfig::get_models_url("gemini");
|
||
|
||
let response = client
|
||
.get(url)
|
||
.header("x-goog-api-key", api_key)
|
||
.send()
|
||
.await
|
||
.map_err(|e| {
|
||
tracing::error!("Gemini API request failed: {}", e);
|
||
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
|
||
})?;
|
||
|
||
if !response.status().is_success() {
|
||
let status = response.status();
|
||
let error_body = response.text().await.unwrap_or_default();
|
||
tracing::error!(
|
||
"Gemini list_models_static failed - Status: {}, Body: {}",
|
||
status,
|
||
error_body
|
||
);
|
||
return Err(WorkerError::from_api_error(
|
||
format!("Failed to list Gemini models: {} - {}", status, error_body),
|
||
&crate::types::LlmProvider::Gemini,
|
||
));
|
||
}
|
||
|
||
let models_response: serde_json::Value = response.json().await.map_err(|e| {
|
||
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
|
||
})?;
|
||
|
||
let mut models = Vec::new();
|
||
|
||
if let Some(models_array) = models_response.get("models").and_then(|m| m.as_array()) {
|
||
for model in models_array {
|
||
if let Some(name) = model.get("name").and_then(|n| n.as_str()) {
|
||
// "models/" プレフィックスを除去
|
||
let model_id = name.strip_prefix("models/").unwrap_or(name);
|
||
|
||
// generateContentメソッドをサポートするモデルのみを含める
|
||
if let Some(supported_methods) = model
|
||
.get("supportedGenerationMethods")
|
||
.and_then(|m| m.as_array())
|
||
{
|
||
let supports_generate_content = supported_methods
|
||
.iter()
|
||
.any(|method| method.as_str() == Some("generateContent"));
|
||
|
||
if supports_generate_content {
|
||
models.push(crate::types::ModelInfo {
|
||
id: model_id.to_string(),
|
||
name: model
|
||
.get("displayName")
|
||
.and_then(|d| d.as_str())
|
||
.unwrap_or(model_id)
|
||
.to_string(),
|
||
provider: crate::types::LlmProvider::Gemini,
|
||
supports_tools: true,
|
||
supports_function_calling: true,
|
||
supports_vision: false,
|
||
supports_multimodal: false,
|
||
context_length: None,
|
||
training_cutoff: None,
|
||
capabilities: vec!["text_generation".to_string()],
|
||
description: model
|
||
.get("description")
|
||
.and_then(|d| d.as_str())
|
||
.map(|s| s.to_string())
|
||
.or_else(|| Some(format!("Google Gemini model: {}", model_id))),
|
||
});
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
tracing::info!(
|
||
"Gemini list_models_static found {} models with metadata",
|
||
models.len()
|
||
);
|
||
Ok(models)
|
||
}
|
||
|
||
pub async fn chat_stream<'a>(
|
||
&'a self,
|
||
messages: Vec<Message>,
|
||
tools: Option<&[DynamicToolDefinition]>,
|
||
llm_debug: Option<crate::types::LlmDebug>,
|
||
) -> Result<
|
||
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
|
||
WorkerError,
|
||
> {
|
||
// Separate system messages from regular messages
|
||
let (system_messages, regular_messages): (Vec<_>, Vec<_>) = messages
|
||
.into_iter()
|
||
.partition(|msg| matches!(msg.role, Role::System));
|
||
|
||
// Create system instruction from system messages
|
||
let system_instruction = if !system_messages.is_empty() {
|
||
let combined_system_content = system_messages
|
||
.into_iter()
|
||
.map(|msg| msg.content)
|
||
.collect::<Vec<_>>()
|
||
.join("\n\n");
|
||
|
||
Some(GeminiContent {
|
||
role: "user".to_string(), // System instruction uses "user" role
|
||
parts: vec![GeminiPart::Text {
|
||
text: combined_system_content,
|
||
}],
|
||
})
|
||
} else {
|
||
None
|
||
};
|
||
|
||
// Process regular messages with proper tool context handling
|
||
let contents = regular_messages
|
||
.into_iter()
|
||
.map(|msg| {
|
||
let (role, parts) = match msg.role {
|
||
Role::User => (
|
||
"user".to_string(),
|
||
vec![GeminiPart::Text { text: msg.content }],
|
||
),
|
||
Role::Model => {
|
||
if let Some(tool_calls) = &msg.tool_calls {
|
||
// Model message with tool calls - convert to FunctionCall parts
|
||
tracing::debug!(
|
||
"Converting model message with {} tool calls to FunctionCall parts",
|
||
tool_calls.len()
|
||
);
|
||
let mut parts = Vec::new();
|
||
|
||
// Add text content if present
|
||
if !msg.content.is_empty() {
|
||
parts.push(GeminiPart::Text {
|
||
text: msg.content.clone(),
|
||
});
|
||
}
|
||
|
||
// Add function calls
|
||
for tool_call in tool_calls {
|
||
tracing::debug!(
|
||
"Adding FunctionCall part for tool: {}",
|
||
tool_call.name
|
||
);
|
||
let args = serde_json::from_str(&tool_call.arguments)
|
||
.unwrap_or(serde_json::json!({}));
|
||
parts.push(GeminiPart::FunctionCall {
|
||
function_call: GeminiFunctionCall {
|
||
name: tool_call.name.clone(),
|
||
args,
|
||
},
|
||
});
|
||
}
|
||
|
||
("model".to_string(), parts)
|
||
} else {
|
||
// Regular model message
|
||
tracing::debug!("Converting regular model message (no tool calls)");
|
||
(
|
||
"model".to_string(),
|
||
vec![GeminiPart::Text { text: msg.content }],
|
||
)
|
||
}
|
||
}
|
||
Role::Tool => {
|
||
// Tool responses should be sent as FunctionResponse
|
||
if let Some(tool_name) = extract_tool_name_from_content(&msg.content) {
|
||
// Extract result from the content
|
||
let result_value = if msg.content.contains("Result: ") {
|
||
if let Some(result_start) = msg.content.find("Result: ") {
|
||
let result_str = &msg.content[result_start + 8..];
|
||
// Try to parse as JSON, fallback to string
|
||
serde_json::from_str(result_str)
|
||
.unwrap_or_else(|_| serde_json::json!(result_str))
|
||
} else {
|
||
serde_json::json!(msg.content)
|
||
}
|
||
} else {
|
||
serde_json::json!(msg.content)
|
||
};
|
||
|
||
(
|
||
"user".to_string(),
|
||
vec![GeminiPart::FunctionResponse {
|
||
function_response: GeminiFunctionResponse {
|
||
name: tool_name,
|
||
response: result_value,
|
||
},
|
||
}],
|
||
)
|
||
} else {
|
||
// Fallback to text response if tool name can't be extracted
|
||
(
|
||
"user".to_string(),
|
||
vec![GeminiPart::Text {
|
||
text: format!("Tool Response:\n{}", msg.content),
|
||
}],
|
||
)
|
||
}
|
||
}
|
||
Role::System => unreachable!(), // Should not reach here after partition
|
||
};
|
||
GeminiContent { role, parts }
|
||
})
|
||
.collect();
|
||
|
||
let tools = tools.map(|tools| {
|
||
vec![GeminiTool {
|
||
function_declarations: tools
|
||
.iter()
|
||
.map(|tool| {
|
||
let mut transformed_schema =
|
||
transform_schema_for_gemini(tool.parameters_schema.clone());
|
||
|
||
// Ensure the schema has the correct structure for Gemini
|
||
match transformed_schema {
|
||
serde_json::Value::Object(ref mut obj) => {
|
||
// Gemini expects the parameters to be an object with type: "object"
|
||
if !obj.contains_key("type") {
|
||
obj.insert(
|
||
"type".to_string(),
|
||
serde_json::Value::String("object".to_string()),
|
||
);
|
||
}
|
||
// If there are no properties, add an empty object
|
||
if !obj.contains_key("properties") {
|
||
obj.insert(
|
||
"properties".to_string(),
|
||
serde_json::Value::Object(serde_json::Map::new()),
|
||
);
|
||
}
|
||
// If there are no required fields, add an empty array
|
||
if !obj.contains_key("required") {
|
||
obj.insert(
|
||
"required".to_string(),
|
||
serde_json::Value::Array(vec![]),
|
||
);
|
||
}
|
||
}
|
||
_ => {
|
||
// If it's not an object, create a proper object schema
|
||
let mut schema_obj = serde_json::Map::new();
|
||
schema_obj.insert(
|
||
"type".to_string(),
|
||
serde_json::Value::String("object".to_string()),
|
||
);
|
||
schema_obj.insert(
|
||
"properties".to_string(),
|
||
serde_json::Value::Object(serde_json::Map::new()),
|
||
);
|
||
schema_obj.insert(
|
||
"required".to_string(),
|
||
serde_json::Value::Array(vec![]),
|
||
);
|
||
transformed_schema = serde_json::Value::Object(schema_obj);
|
||
}
|
||
}
|
||
|
||
GeminiFunctionDeclaration {
|
||
name: tool.name.clone(),
|
||
description: tool.description.clone(),
|
||
parameters: transformed_schema,
|
||
}
|
||
})
|
||
.collect(),
|
||
}]
|
||
});
|
||
|
||
let request = GeminiRequest {
|
||
contents,
|
||
system_instruction,
|
||
tools,
|
||
};
|
||
|
||
let stream = stream_events(&self.api_key, &self.model, request, llm_debug)
|
||
.map_err(|e| WorkerError::llm_api("gemini", e.to_string()));
|
||
|
||
Ok(Box::new(Box::pin(stream)))
|
||
}
|
||
|
||
pub async fn get_model_details(
|
||
&self,
|
||
model_name: &str,
|
||
) -> Result<crate::types::ModelInfo, WorkerError> {
|
||
let client = Client::new();
|
||
let url = UrlConfig::get_model_url("gemini", model_name);
|
||
|
||
let response = client
|
||
.get(&url)
|
||
.header("x-goog-api-key", &self.api_key)
|
||
.send()
|
||
.await
|
||
.map_err(|e| {
|
||
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
|
||
})?;
|
||
|
||
if !response.status().is_success() {
|
||
return Err(WorkerError::from_api_error(
|
||
format!(
|
||
"Gemini model details request failed with status: {}",
|
||
response.status()
|
||
),
|
||
&crate::types::LlmProvider::Gemini,
|
||
));
|
||
}
|
||
|
||
let model_data: serde_json::Value = response.json().await.map_err(|e| {
|
||
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Gemini)
|
||
})?;
|
||
|
||
let name = model_data
|
||
.get("name")
|
||
.and_then(|n| n.as_str())
|
||
.unwrap_or(model_name);
|
||
let display_name = model_data
|
||
.get("displayName")
|
||
.and_then(|d| d.as_str())
|
||
.unwrap_or(name);
|
||
let description = model_data
|
||
.get("description")
|
||
.and_then(|d| d.as_str())
|
||
.unwrap_or("");
|
||
let version = model_data
|
||
.get("version")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.to_string());
|
||
|
||
let input_token_limit = model_data
|
||
.get("inputTokenLimit")
|
||
.and_then(|i| i.as_u64())
|
||
.map(|i| i as u32);
|
||
let _output_token_limit = model_data
|
||
.get("outputTokenLimit")
|
||
.and_then(|o| o.as_u64())
|
||
.map(|o| o as u32);
|
||
|
||
let empty_vec = Vec::new();
|
||
let supported_methods = model_data
|
||
.get("supportedGenerationMethods")
|
||
.and_then(|s| s.as_array())
|
||
.unwrap_or(&empty_vec);
|
||
|
||
let supports_tools = supported_methods
|
||
.iter()
|
||
.any(|method| method.as_str() == Some("generateContent"));
|
||
|
||
let supports_vision = false; // Will be determined dynamically
|
||
let capabilities = vec!["text_generation".to_string()]; // Basic default
|
||
|
||
Ok(crate::types::ModelInfo {
|
||
id: model_name.to_string(),
|
||
name: display_name.to_string(),
|
||
provider: crate::types::LlmProvider::Gemini,
|
||
supports_tools,
|
||
supports_function_calling: supports_tools,
|
||
supports_vision,
|
||
supports_multimodal: supports_vision,
|
||
context_length: input_token_limit,
|
||
training_cutoff: version,
|
||
capabilities,
|
||
description: Some(if description.is_empty() {
|
||
format!("Google Gemini model: {}", display_name)
|
||
} else {
|
||
description.to_string()
|
||
}),
|
||
})
|
||
}
|
||
|
||
pub async fn check_connection(&self) -> Result<(), WorkerError> {
|
||
// Simple connection check - try to call the API
|
||
// For now, just return OK if model is not empty
|
||
if self.model.is_empty() {
|
||
return Err(WorkerError::model_not_found("gemini", "No model specified"));
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl LlmClientTrait for GeminiClient {
|
||
async fn chat_stream<'a>(
|
||
&'a self,
|
||
messages: Vec<Message>,
|
||
tools: Option<&[DynamicToolDefinition]>,
|
||
llm_debug: Option<crate::types::LlmDebug>,
|
||
) -> Result<
|
||
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
|
||
WorkerError,
|
||
> {
|
||
self.chat_stream(messages, tools, llm_debug).await
|
||
}
|
||
|
||
async fn check_connection(&self) -> Result<(), WorkerError> {
|
||
self.check_connection().await
|
||
}
|
||
|
||
fn provider(&self) -> LlmProvider {
|
||
LlmProvider::Gemini
|
||
}
|
||
|
||
fn get_model_name(&self) -> String {
|
||
self.get_model_name()
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_schema_transformation() {
        // Test schema with various type formats including $schema
        let schema = serde_json::json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "type": "object",
            "properties": {
                "id": {
                    "type": "uint"
                },
                "optional_number": {
                    "type": ["integer", "null"]
                },
                "required_string": {
                    "type": "string"
                },
                "existing_integer": {
                    "type": "integer"
                },
                "nested": {
                    "type": "object",
                    "properties": {
                        "count": {
                            "type": ["uint", "null"]
                        }
                    }
                }
            },
            "required": ["id", "optional_number", "required_string"]
        });

        let transformed = transform_schema_for_gemini(schema);

        // Check that the schema has the correct structure
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());

        // Check that $schema key is removed
        assert!(transformed.get("$schema").is_none());

        // Check that 'uint' was transformed to 'integer'
        assert_eq!(transformed["properties"]["id"]["type"], "integer");
        assert_eq!(transformed["properties"]["id"]["format"], "int64");

        // Check that array types are converted to single types
        assert_eq!(
            transformed["properties"]["optional_number"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["optional_number"]["format"],
            "int64"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["nested"]["properties"]["count"]["format"],
            "int64"
        );

        // Check that existing integer types also get format
        assert_eq!(
            transformed["properties"]["existing_integer"]["type"],
            "integer"
        );
        assert_eq!(
            transformed["properties"]["existing_integer"]["format"],
            "int64"
        );

        // Check that required array is updated correctly (nullable properties should be removed)
        let required: Vec<&str> = transformed["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();

        assert!(required.contains(&"id"));
        assert!(required.contains(&"required_string"));
        assert!(!required.contains(&"optional_number")); // Should be removed because it's nullable
    }

    #[test]
    fn test_nested_required_filtering() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "count": { "type": ["uint", "null"] },
                        "label": { "type": "string" }
                    },
                    "required": ["count", "label"]
                }
            }
        });

        let transformed = transform_schema_for_gemini(schema);

        // Nullable nested property is dropped from the nested `required` list
        assert_eq!(
            transformed["properties"]["nested"]["required"],
            serde_json::json!(["label"])
        );
        // No `required` key is invented where the input had none
        assert!(transformed.get("required").is_none());
    }

    #[test]
    fn test_empty_schema_transformation() {
        // Test with an empty schema as would be processed in tool generation
        let schema = serde_json::json!({});
        let mut transformed = transform_schema_for_gemini(schema);

        // Apply the same logic as in tool generation
        match transformed {
            serde_json::Value::Object(ref mut obj) => {
                if !obj.contains_key("type") {
                    obj.insert(
                        "type".to_string(),
                        serde_json::Value::String("object".to_string()),
                    );
                }
                if !obj.contains_key("properties") {
                    obj.insert(
                        "properties".to_string(),
                        serde_json::Value::Object(serde_json::Map::new()),
                    );
                }
                if !obj.contains_key("required") {
                    obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                }
            }
            _ => {
                let mut schema_obj = serde_json::Map::new();
                schema_obj.insert(
                    "type".to_string(),
                    serde_json::Value::String("object".to_string()),
                );
                schema_obj.insert(
                    "properties".to_string(),
                    serde_json::Value::Object(serde_json::Map::new()),
                );
                schema_obj.insert("required".to_string(), serde_json::Value::Array(vec![]));
                transformed = serde_json::Value::Object(schema_obj);
            }
        }

        // Should be converted to a proper object schema
        assert_eq!(transformed["type"], "object");
        assert!(transformed["properties"].is_object());
        assert!(transformed["required"].is_array());
    }

    #[test]
    fn test_find_first_json_object_bounds() {
        // Streamed array framing: the first element is found, the brackets are skipped
        assert_eq!(
            find_first_json_object_bounds(b"[{\"a\":1},{\"b\":2}]"),
            Some((1, 8))
        );
        // Braces and escaped quotes inside strings do not affect depth
        assert_eq!(
            find_first_json_object_bounds(b"{\"a\":\"\\\"}\"}"),
            Some((0, 11))
        );
        // Incomplete object and no object at all
        assert_eq!(find_first_json_object_bounds(b"{\"a\":{\"b\":"), None);
        assert_eq!(find_first_json_object_bounds(b"]"), None);
    }

    #[test]
    fn test_extract_tool_name_from_content() {
        assert_eq!(
            extract_tool_name_from_content("Tool 'search' executed successfully"),
            Some("search".to_string())
        );
        assert_eq!(extract_tool_name_from_content("Tool 'unterminated"), None);
        assert_eq!(extract_tool_name_from_content("plain text"), None);
    }
}
|