393 lines
15 KiB
Rust
393 lines
15 KiB
Rust
use crate::config::UrlConfig;
|
|
use crate::core::LlmClientTrait;
|
|
use crate::types::WorkerError;
|
|
use async_stream::stream;
|
|
use futures_util::{Stream, StreamExt};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall};
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct AnthropicRequest {
|
|
model: String,
|
|
max_tokens: i32,
|
|
messages: Vec<AnthropicMessage>,
|
|
stream: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<AnthropicTool>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
system: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
struct AnthropicMessage {
|
|
role: String,
|
|
content: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
struct AnthropicTool {
|
|
name: String,
|
|
description: String,
|
|
input_schema: Value,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct AnthropicResponse {
|
|
#[serde(rename = "type")]
|
|
response_type: String,
|
|
content: Vec<AnthropicContent>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
struct AnthropicStreamResponse {
|
|
#[serde(rename = "type")]
|
|
response_type: String,
|
|
#[serde(flatten)]
|
|
data: Value,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(tag = "type")]
|
|
enum AnthropicContent {
|
|
#[serde(rename = "text")]
|
|
Text { text: String },
|
|
#[serde(rename = "tool_use")]
|
|
ToolUse {
|
|
id: String,
|
|
name: String,
|
|
input: Value,
|
|
},
|
|
}
|
|
|
|
/// Client for the Anthropic completion API.
///
/// Holds the credentials and the model name used for every request issued by
/// this client.
pub struct AnthropicClient {
    /// Secret sent in the `x-api-key` request header.
    api_key: String,
    /// Model identifier placed in each request body.
    model: String,
}

/// Hand-written so the client can be logged or embedded in `Debug` output of
/// other types without ever printing the API key (a derived impl would).
impl std::fmt::Debug for AnthropicClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnthropicClient")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
|
|
|
|
impl AnthropicClient {
|
|
pub fn new(api_key: &str, model: &str) -> Self {
|
|
Self {
|
|
api_key: api_key.to_string(),
|
|
model: model.to_string(),
|
|
}
|
|
}
|
|
|
|
pub fn get_model_name(&self) -> String {
|
|
self.model.clone()
|
|
}
|
|
}
|
|
|
|
impl AnthropicClient {
|
|
pub async fn chat_stream<'a>(
|
|
&'a self,
|
|
messages: Vec<Message>,
|
|
tools: Option<&[crate::types::DynamicToolDefinition]>,
|
|
llm_debug: Option<crate::types::LlmDebug>,
|
|
) -> Result<
|
|
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
|
|
WorkerError,
|
|
> {
|
|
let client = Client::new();
|
|
let url = UrlConfig::get_completion_url("anthropic");
|
|
|
|
// Separate system messages from other messages
|
|
let mut system_message: Option<String> = None;
|
|
let mut anthropic_messages: Vec<AnthropicMessage> = Vec::new();
|
|
|
|
for msg in messages {
|
|
match msg.role {
|
|
Role::System => {
|
|
// Combine multiple system messages if they exist
|
|
if let Some(existing) = system_message {
|
|
system_message = Some(format!(
|
|
"{}
|
|
|
|
{}",
|
|
existing, msg.content
|
|
));
|
|
} else {
|
|
system_message = Some(msg.content);
|
|
}
|
|
}
|
|
Role::User => {
|
|
anthropic_messages.push(AnthropicMessage {
|
|
role: "user".to_string(),
|
|
content: msg.content,
|
|
});
|
|
}
|
|
Role::Model => {
|
|
anthropic_messages.push(AnthropicMessage {
|
|
role: "assistant".to_string(),
|
|
content: msg.content,
|
|
});
|
|
}
|
|
Role::Tool => {
|
|
anthropic_messages.push(AnthropicMessage {
|
|
role: "user".to_string(),
|
|
content: msg.content,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert tools to Anthropic format
|
|
let anthropic_tools = tools.map(|tools| {
|
|
tools
|
|
.iter()
|
|
.map(|tool| AnthropicTool {
|
|
name: tool.name.clone(),
|
|
description: tool.description.clone(),
|
|
input_schema: tool.parameters_schema.clone(),
|
|
})
|
|
.collect()
|
|
});
|
|
|
|
let request = AnthropicRequest {
|
|
model: self.model.clone(),
|
|
max_tokens: 4096,
|
|
messages: anthropic_messages,
|
|
stream: true,
|
|
tools: anthropic_tools,
|
|
system: system_message,
|
|
};
|
|
|
|
// Log request details for debugging
|
|
tracing::debug!(
|
|
"Anthropic API request: {}",
|
|
serde_json::to_string_pretty(&request).unwrap_or_default()
|
|
);
|
|
|
|
let response = client
|
|
.post(url)
|
|
.header("Content-Type", "application/json")
|
|
.header("x-api-key", &self.api_key)
|
|
.header("anthropic-version", "2023-06-01")
|
|
.json(&request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
|
|
})?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_body = response.text().await.unwrap_or_default();
|
|
return Err(WorkerError::from_api_error(
|
|
format!("Anthropic API error: {} - {}", status, error_body),
|
|
&crate::types::LlmProvider::Claude,
|
|
));
|
|
}
|
|
|
|
let stream = stream! {
|
|
// デバッグ情報を送信
|
|
if let Some(ref debug) = llm_debug {
|
|
if let Some(debug_event) = debug.debug_request(&self.model, "Anthropic", &serde_json::to_value(&request).unwrap_or_default()) {
|
|
yield Ok(debug_event);
|
|
}
|
|
}
|
|
|
|
let mut stream = response.bytes_stream();
|
|
let mut buffer = String::new();
|
|
|
|
while let Some(chunk) = stream.next().await {
|
|
match chunk {
|
|
Ok(bytes) => {
|
|
let chunk_str = String::from_utf8_lossy(&bytes);
|
|
buffer.push_str(&chunk_str);
|
|
|
|
// Server-sent eventsを処理
|
|
while let Some(line_end) = buffer.find('\n') {
|
|
let line = buffer[..line_end].to_string();
|
|
buffer = buffer[line_end + 1..].to_string();
|
|
|
|
if line.starts_with("data: ") {
|
|
let data = &line[6..];
|
|
if data == "[DONE]" {
|
|
break;
|
|
}
|
|
|
|
match serde_json::from_str::<AnthropicStreamResponse>(data) {
|
|
Ok(stream_response) => {
|
|
// デバッグ情報を送信
|
|
if let Some(ref debug) = llm_debug {
|
|
if let Some(debug_event) = debug.debug_response(&self.model, "Anthropic", &serde_json::to_value(&stream_response).unwrap_or_default()) {
|
|
yield Ok(debug_event);
|
|
}
|
|
}
|
|
match stream_response.response_type.as_str() {
|
|
"content_block_delta" => {
|
|
if let Some(delta) = stream_response.data.get("delta") {
|
|
if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
|
|
yield Ok(StreamEvent::Chunk(text.to_string()));
|
|
}
|
|
}
|
|
}
|
|
"content_block_start" => {
|
|
if let Some(content_block) = stream_response.data.get("content_block") {
|
|
if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
|
|
if block_type == "tool_use" {
|
|
if let (Some(name), Some(input)) = (
|
|
content_block.get("name").and_then(|n| n.as_str()),
|
|
content_block.get("input")
|
|
) {
|
|
let tool_call = ToolCall {
|
|
name: name.to_string(),
|
|
arguments: input.to_string(),
|
|
};
|
|
yield Ok(StreamEvent::ToolCall(tool_call));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
"message_start" => {
|
|
tracing::debug!("Anthropic message stream started");
|
|
}
|
|
"message_delta" => {
|
|
if let Some(delta) = stream_response.data.get("delta") {
|
|
if let Some(stop_reason) = delta.get("stop_reason") {
|
|
tracing::debug!("Anthropic message stop reason: {}", stop_reason);
|
|
}
|
|
}
|
|
}
|
|
"message_stop" => {
|
|
tracing::debug!("Anthropic message stream stopped");
|
|
yield Ok(StreamEvent::Completion(Message::new(
|
|
Role::Model,
|
|
"".to_string(),
|
|
)));
|
|
break;
|
|
}
|
|
"content_block_stop" => {
|
|
tracing::debug!("Anthropic content block stopped");
|
|
}
|
|
"ping" => {
|
|
tracing::debug!("Anthropic ping received");
|
|
}
|
|
"error" => {
|
|
if let Some(error) = stream_response.data.get("error") {
|
|
let error_msg = error.get("message")
|
|
.and_then(|m| m.as_str())
|
|
.unwrap_or("Unknown error");
|
|
tracing::error!("Anthropic stream error: {}", error_msg);
|
|
yield Err(WorkerError::from_api_error(
|
|
format!("Anthropic stream error: {}", error_msg),
|
|
&crate::types::LlmProvider::Claude,
|
|
));
|
|
}
|
|
}
|
|
_ => {
|
|
tracing::debug!("Unhandled Anthropic stream event: {}", stream_response.response_type);
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("Failed to parse Anthropic stream response: {} - Raw data: {}", e, data);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Box::new(Box::pin(stream)))
|
|
}
|
|
|
|
pub async fn check_connection(&self) -> Result<(), WorkerError> {
|
|
let client = Client::new();
|
|
let url = UrlConfig::get_completion_url("anthropic");
|
|
|
|
// Use a default valid model for connection testing if model is empty
|
|
let test_model = if self.model.is_empty() {
|
|
"claude-3-haiku-20240307".to_string()
|
|
} else {
|
|
self.model.clone()
|
|
};
|
|
|
|
tracing::debug!(
|
|
"Anthropic connection test: Using model '{}' with API key length: {}",
|
|
test_model,
|
|
self.api_key.len()
|
|
);
|
|
|
|
let test_request = AnthropicRequest {
|
|
model: test_model,
|
|
max_tokens: 1,
|
|
messages: vec![AnthropicMessage {
|
|
role: "user".to_string(),
|
|
content: "Hi".to_string(),
|
|
}],
|
|
stream: false,
|
|
tools: None,
|
|
system: None,
|
|
};
|
|
|
|
let response = client
|
|
.post(url)
|
|
.header("Content-Type", "application/json")
|
|
.header("x-api-key", &self.api_key)
|
|
.header("anthropic-version", "2023-06-01")
|
|
.json(&test_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| {
|
|
tracing::error!("Anthropic connection test network error: {}", e);
|
|
WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
|
|
})?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_body = response.text().await.unwrap_or_default();
|
|
tracing::error!(
|
|
"Anthropic connection test failed: Status={}, Body={}",
|
|
status,
|
|
error_body
|
|
);
|
|
return Err(WorkerError::from_api_error(
|
|
format!(
|
|
"Anthropic connection test failed: {} - {}",
|
|
status, error_body
|
|
),
|
|
&crate::types::LlmProvider::Claude,
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl LlmClientTrait for AnthropicClient {
|
|
async fn chat_stream<'a>(
|
|
&'a self,
|
|
messages: Vec<Message>,
|
|
tools: Option<&[crate::types::DynamicToolDefinition]>,
|
|
llm_debug: Option<crate::types::LlmDebug>,
|
|
) -> Result<
|
|
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
|
|
WorkerError,
|
|
> {
|
|
self.chat_stream(messages, tools, llm_debug).await
|
|
}
|
|
|
|
async fn check_connection(&self) -> Result<(), WorkerError> {
|
|
self.check_connection().await
|
|
}
|
|
|
|
fn provider(&self) -> LlmProvider {
|
|
LlmProvider::Claude
|
|
}
|
|
|
|
fn get_model_name(&self) -> String {
|
|
self.get_model_name()
|
|
}
|
|
}
|