use crate::config::UrlConfig;
use crate::core::LlmClientTrait;
use crate::types::WorkerError;

use async_stream::stream;
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use worker_types::{DynamicToolDefinition, LlmProvider, Message, Role, StreamEvent, ToolCall};

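/// Request body for xAI's chat completions endpoint. The wire format is
/// OpenAI-compatible, so the field names mirror that API.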
#[derive(Debug, Serialize)]
pub(crate) struct XAIRequest {
    pub model: String,
    pub messages: Vec<XAIMessage>,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<XAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
}

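/// A single chat message in the OpenAI-compatible format used by xAI.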
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIMessage {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<XAIToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: XAIFunction,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XAIFunction {
    pub name: String,
    pub arguments: String,
}

#[derive(Debug, Serialize, Clone)]
pub struct XAITool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: XAIFunctionDef,
}

#[derive(Debug, Serialize, Clone)]
pub struct XAIFunctionDef {
    pub name: String,
    pub description: String,
    pub parameters: Value,
}

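/// Non-streaming response body. `chat_stream` parses streaming chunks as raw
/// JSON instead, since those carry `delta` fields rather than a full `message`.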
#[derive(Debug, Deserialize)]
pub(crate) struct XAIResponse {
    pub choices: Vec<XAIChoice>,
}

#[derive(Debug, Deserialize)]
pub struct XAIChoice {
    pub message: XAIMessage,
    // `Option` fields already deserialize to `None` when absent; a
    // `skip_serializing_if` attribute would have no effect on this
    // `Deserialize`-only type, so none is given.
    pub delta: Option<XAIDelta>,
}

#[derive(Debug, Deserialize)]
pub struct XAIDelta {
    pub content: Option<String>,
    pub tool_calls: Option<Vec<XAIToolCall>>,
}

#[derive(Debug, Deserialize)]
pub struct XAIModel {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub owned_by: String,
}

#[derive(Debug, Deserialize)]
pub struct XAIModelsResponse {
    pub object: String,
    pub data: Vec<XAIModel>,
}

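/// Thin HTTP client for the xAI API.
///
/// A minimal usage sketch (the key and model name below are placeholders,
/// not values this crate ships with):
///
/// ```ignore
/// let client = XAIClient::new("xai-api-key", "grok-2-latest");
/// client.check_connection().await?;
/// ```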
pub struct XAIClient {
    api_key: String,
    model: String,
}

impl XAIClient {
    pub fn new(api_key: &str, model: &str) -> Self {
        Self {
            api_key: api_key.to_string(),
            model: model.to_string(),
        }
    }

    pub fn get_model_name(&self) -> String {
        self.model.clone()
    }

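    /// Streams a chat completion as server-sent events.
    ///
    /// A consumption sketch (`handle_tool_call` is a hypothetical caller-side
    /// helper, not part of this crate):
    ///
    /// ```ignore
    /// let mut events = client.chat_stream(messages, None, None).await?;
    /// while let Some(event) = events.next().await {
    ///     match event? {
    ///         StreamEvent::Chunk(text) => print!("{text}"),
    ///         StreamEvent::ToolCall(call) => handle_tool_call(call),
    ///         StreamEvent::Completion(_) => break,
    ///         _ => {}
    ///     }
    /// }
    /// ```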
    pub async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebug>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        let client = Client::new();
        let url = UrlConfig::get_completion_url("xai");

        let xai_messages: Vec<XAIMessage> = messages
            .into_iter()
            .map(|msg| XAIMessage {
                role: match msg.role {
                    Role::User => "user".to_string(),
                    Role::Model => "assistant".to_string(),
                    Role::System => "system".to_string(),
                    Role::Tool => "tool".to_string(),
                },
                content: msg.content,
                tool_calls: None,
            })
            .collect();

        let xai_tools = tools.map(|tools| {
            tools
                .iter()
                .map(|tool| XAITool {
                    tool_type: "function".to_string(),
                    function: XAIFunctionDef {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters_schema.clone(),
                    },
                })
                .collect()
        });

        let request = XAIRequest {
            model: self.model.clone(),
            messages: xai_messages,
            stream: true,
            tools: xai_tools,
            max_tokens: None,
            temperature: None,
        };

        let response = client
            .post(url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(WorkerError::from_api_error(
                format!("xAI API error: {} - {}", status, error_body),
                &crate::types::LlmProvider::XAI,
            ));
        }

        let stream = stream! {
            if let Some(ref debug) = llm_debug {
                if let Some(debug_event) = debug.debug_request(&self.model, "xAI", &serde_json::to_value(&request).unwrap_or_default()) {
                    yield Ok(debug_event);
                }
            }

            let mut stream = response.bytes_stream();
            // SSE frames can arrive split across chunks, so accumulate bytes
            // and only process complete lines.
            let mut buffer = String::new();

            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        let chunk_str = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&chunk_str);

                        while let Some(line_end) = buffer.find('\n') {
                            // Trim trailing whitespace so a '\r' from CRLF line
                            // endings does not break the "[DONE]" comparison.
                            let line = buffer[..line_end].trim_end().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if let Some(data) = line.strip_prefix("data: ") {
                                if data == "[DONE]" {
                                    // End-of-stream sentinel; emit a final Completion event.
                                    yield Ok(StreamEvent::Completion(Message::new(
                                        Role::Model,
                                        "".to_string(),
                                    )));
                                    break;
                                }

                                match serde_json::from_str::<Value>(data) {
                                    Ok(json_data) => {
                                        if let Some(ref debug) = llm_debug {
                                            if let Some(debug_event) = debug.debug_response(&self.model, "xAI", &json_data) {
                                                yield Ok(debug_event);
                                            }
                                        }
                                        if let Some(choices) = json_data.get("choices").and_then(|c| c.as_array()) {
                                            for choice in choices {
                                                if let Some(delta) = choice.get("delta") {
                                                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                                                        yield Ok(StreamEvent::Chunk(content.to_string()));
                                                    }

                                                    // Note: this expects the function name and full
                                                    // arguments to arrive in a single delta; partial
                                                    // (incrementally streamed) tool calls are ignored.
                                                    if let Some(tool_calls) = delta.get("tool_calls").and_then(|tc| tc.as_array()) {
                                                        for tool_call in tool_calls {
                                                            if let Some(function) = tool_call.get("function") {
                                                                if let (Some(name), Some(arguments)) = (
                                                                    function.get("name").and_then(|n| n.as_str()),
                                                                    function.get("arguments").and_then(|a| a.as_str())
                                                                ) {
                                                                    let tool_call = ToolCall {
                                                                        name: name.to_string(),
                                                                        arguments: arguments.to_string(),
                                                                    };
                                                                    yield Ok(StreamEvent::ToolCall(tool_call));
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        tracing::warn!("Failed to parse xAI stream response: {}", e);
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI));
                        break;
                    }
                }
            }
        };

        Ok(Box::new(Box::pin(stream)))
    }

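    /// Fetches metadata for a single model. Capability flags and context
    /// length are placeholders here; the caller is expected to resolve them
    /// from configuration.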
    pub async fn get_model_details(
        &self,
        model_id: &str,
    ) -> Result<crate::types::ModelInfo, WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_model_url("xai", model_id);

        let response = client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
            })?;

        if !response.status().is_success() {
            return Err(WorkerError::from_api_error(
                format!(
                    "xAI model details request failed with status: {}",
                    response.status()
                ),
                &crate::types::LlmProvider::XAI,
            ));
        }

        let model_data: XAIModel = response.json().await.map_err(|e| {
            WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
        })?;

        let supports_tools = true; // Will be determined by config
        let supports_vision = false; // Will be determined by config
        let context_length = None; // Will be determined by config
        let capabilities = vec!["text_generation".to_string()]; // Basic default
        let description = format!("xAI {} model ({})", model_data.id, model_data.owned_by);

        Ok(crate::types::ModelInfo {
            id: model_data.id.clone(),
            name: format!("{} ({})", model_data.id, model_data.owned_by),
            provider: crate::types::LlmProvider::XAI,
            supports_tools,
            supports_function_calling: supports_tools,
            supports_vision,
            supports_multimodal: supports_vision,
            context_length,
            training_cutoff: Some(
                chrono::DateTime::from_timestamp(model_data.created, 0)
                    .map(|dt| dt.format("%Y-%m-%d").to_string())
                    .unwrap_or_else(|| "2024-12-12".to_string()),
            ),
            capabilities,
            description: Some(description),
        })
    }

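    /// Verifies the API key and model by sending a tiny non-streaming request.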
    pub async fn check_connection(&self) -> Result<(), WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_completion_url("xai");

        let test_request = XAIRequest {
            model: self.model.clone(),
            messages: vec![XAIMessage {
                role: "user".to_string(),
                content: "Hi".to_string(),
                tool_calls: None,
            }],
            stream: false,
            tools: None,
            max_tokens: Some(10),
            temperature: Some(0.1),
        };

        let response = client
            .post(url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&test_request)
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::XAI)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(WorkerError::from_api_error(
                format!("xAI connection test failed: {} - {}", status, error_body),
                &crate::types::LlmProvider::XAI,
            ));
        }

        Ok(())
    }
}

#[async_trait::async_trait]
impl LlmClientTrait for XAIClient {
    async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebug>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        // Delegate to the inherent method; inherent methods take precedence
        // over trait methods, so this does not recurse.
        self.chat_stream(messages, tools, llm_debug).await
    }

    async fn check_connection(&self) -> Result<(), WorkerError> {
        self.check_connection().await
    }

    fn provider(&self) -> LlmProvider {
        LlmProvider::XAI
    }

    fn get_model_name(&self) -> String {
        self.get_model_name()
    }
}