use crate::core::LlmClientTrait; use crate::types::WorkerError; use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall}; use crate::config::UrlConfig; use futures_util::{Stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; #[derive(Debug, Serialize)] pub(crate) struct XAIRequest { pub model: String, pub messages: Vec, #[serde(skip_serializing_if = "std::ops::Not::not")] pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct XAIMessage { pub role: String, pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct XAIToolCall { pub id: String, #[serde(rename = "type")] pub call_type: String, pub function: XAIFunction, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct XAIFunction { pub name: String, pub arguments: String, } #[derive(Debug, Serialize, Clone)] pub struct XAITool { #[serde(rename = "type")] pub tool_type: String, pub function: XAIFunctionDef, } #[derive(Debug, Serialize, Clone)] pub struct XAIFunctionDef { pub name: String, pub description: String, pub parameters: Value, } #[derive(Debug, Deserialize)] pub(crate) struct XAIResponse { pub choices: Vec, } #[derive(Debug, Deserialize)] pub struct XAIChoice { pub message: XAIMessage, #[serde(skip_serializing_if = "Option::is_none")] pub delta: Option, } #[derive(Debug, Deserialize)] pub struct XAIDelta { pub content: Option, pub tool_calls: Option>, } #[derive(Debug, Deserialize)] pub struct XAIModel { pub id: String, pub object: String, pub created: i64, pub owned_by: String, } #[derive(Debug, Deserialize)] pub struct XAIModelsResponse { pub object: String, pub data: Vec, } pub struct XAIClient { api_key: String, model: String, } impl XAIClient { pub fn new(api_key: &str, model: &str) -> Self { Self { api_key: api_key.to_string(), model: model.to_string(), } } pub fn get_model_name(&self) -> String { self.model.clone() } } use async_stream::stream; impl XAIClient { pub async fn chat_stream<'a>( &'a self, messages: Vec, tools: Option<&[crate::types::DynamicToolDefinition]>, llm_debug: Option, ) -> Result< Box> + Unpin + Send + 'a>, WorkerError, > { let client = Client::new(); let url = UrlConfig::get_completion_url("xai"); let xai_messages: Vec = messages .into_iter() .map(|msg| XAIMessage { role: match msg.role { Role::User => "user".to_string(), Role::Model => "assistant".to_string(), Role::System => "system".to_string(), Role::Tool => "tool".to_string(), }, content: msg.content, tool_calls: None, }) .collect(); let xai_tools = tools.map(|tools| { tools .iter() .map(|tool| XAITool { tool_type: "function".to_string(), function: XAIFunctionDef { name: tool.name.clone(), description: tool.description.clone(), parameters: tool.parameters_schema.clone(), }, }) .collect() }); let request = XAIRequest { model: self.model.clone(), messages: xai_messages, stream: true, tools: xai_tools, max_tokens: None, temperature: None, }; let response = client .post(url) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", self.api_key)) .json(&request) .send() .await .map_err(|e| { WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI) })?; if !response.status().is_success() { let status = response.status(); let error_body = response.text().await.unwrap_or_default(); return Err(WorkerError::from_api_error( format!("xAI API error: {} - {}", status, error_body), &crate::types::LlmProvider::XAI, )); } let stream = stream! { if let Some(ref debug) = llm_debug { if let Some(debug_event) = debug.debug_request(&self.model, "xAI", &serde_json::to_value(&request).unwrap_or_default()) { yield Ok(debug_event); } } let mut stream = response.bytes_stream(); let mut buffer = String::new(); while let Some(chunk) = stream.next().await { match chunk { Ok(bytes) => { let chunk_str = String::from_utf8_lossy(&bytes); buffer.push_str(&chunk_str); while let Some(line_end) = buffer.find('\n') { let line = buffer[..line_end].to_string(); buffer = buffer[line_end + 1..].to_string(); if line.starts_with("data: ") { let data = &line[6..]; if data == "[DONE]" { yield Ok(StreamEvent::Completion(Message::new( Role::Model, "".to_string(), ))); break; } match serde_json::from_str::(data) { Ok(json_data) => { if let Some(ref debug) = llm_debug { if let Some(debug_event) = debug.debug_response(&self.model, "xAI", &json_data) { yield Ok(debug_event); } } if let Some(choices) = json_data.get("choices").and_then(|c| c.as_array()) { for choice in choices { if let Some(delta) = choice.get("delta") { if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { yield Ok(StreamEvent::Chunk(content.to_string())); } if let Some(tool_calls) = delta.get("tool_calls").and_then(|tc| tc.as_array()) { for tool_call in tool_calls { if let Some(function) = tool_call.get("function") { if let (Some(name), Some(arguments)) = ( function.get("name").and_then(|n| n.as_str()), function.get("arguments").and_then(|a| a.as_str()) ) { let tool_call = ToolCall { name: name.to_string(), arguments: arguments.to_string(), }; yield Ok(StreamEvent::ToolCall(tool_call)); } } } } } } } } Err(e) => { tracing::warn!("Failed to parse xAI stream response: {}", e); } } } } } Err(e) => { yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)); break; } } } }; Ok(Box::new(Box::pin(stream))) } pub async fn get_model_details( &self, model_id: &str, ) -> Result { let client = Client::new(); let url = UrlConfig::get_model_url("xai", model_id); let response = client .get(&url) .header("Authorization", format!("Bearer {}", self.api_key)) .send() .await .map_err(|e| { WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI) })?; if !response.status().is_success() { return Err(WorkerError::from_api_error( format!( "xAI model details request failed with status: {}", response.status() ), &crate::types::LlmProvider::XAI, )); } let model_data: XAIModel = response.json().await.map_err(|e| { WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI) })?; let supports_tools = true; // Will be determined by config let supports_vision = false; // Will be determined by config let context_length = None; // Will be determined by config let capabilities = vec!["text_generation".to_string()]; // Basic default let description = format!("xAI {} model ({})", model_data.id, model_data.owned_by); Ok(crate::types::ModelInfo { id: model_data.id.clone(), name: format!("{} ({})", model_data.id, model_data.owned_by), provider: crate::types::LlmProvider::XAI, supports_tools, supports_function_calling: supports_tools, supports_vision, supports_multimodal: supports_vision, context_length, training_cutoff: Some( chrono::DateTime::from_timestamp(model_data.created, 0) .map(|dt| dt.format("%Y-%m-%d").to_string()) .unwrap_or_else(|| "2024-12-12".to_string()), ), capabilities, description: Some(description), }) } pub async fn check_connection(&self) -> Result<(), WorkerError> { let client = Client::new(); let url = UrlConfig::get_completion_url("xai"); let test_request = XAIRequest { model: self.model.clone(), messages: vec![XAIMessage { role: "user".to_string(), content: "Hi".to_string(), tool_calls: None, }], stream: false, tools: None, max_tokens: Some(10), temperature: Some(0.1), }; let response = client .post(url) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", self.api_key)) .json(&test_request) .send() .await .map_err(|e| { WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI) })?; if !response.status().is_success() { let status = response.status(); let error_body = response.text().await.unwrap_or_default(); return Err(WorkerError::from_api_error( format!("xAI connection test failed: {} - {}", status, error_body), &crate::types::LlmProvider::XAI, )); } Ok(()) } } #[async_trait::async_trait] impl LlmClientTrait for XAIClient { async fn chat_stream<'a>( &'a self, messages: Vec, tools: Option<&[DynamicToolDefinition]>, llm_debug: Option, ) -> Result< Box> + Unpin + Send + 'a>, WorkerError, > { self.chat_stream(messages, tools, llm_debug).await } async fn check_connection(&self) -> Result<(), WorkerError> { self.check_connection().await } fn provider(&self) -> LlmProvider { LlmProvider::XAI } fn get_model_name(&self) -> String { self.get_model_name() } }