//! Anthropic (Claude) LLM client — llm-worker-rs/worker/src/llm/anthropic.rs

use crate::config::UrlConfig;
use crate::core::LlmClientTrait;
use crate::types::WorkerError;
use async_stream::stream;
use futures_util::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use worker_types::{LlmProvider, Message, Role, StreamEvent, ToolCall};
/// Request body for Anthropic's Messages API (`POST /v1/messages`).
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    /// Model identifier, e.g. "claude-3-haiku-20240307".
    model: String,
    /// Maximum number of tokens to generate (required by the API).
    max_tokens: i32,
    /// Conversation turns; system prompts go in `system`, not here.
    messages: Vec<AnthropicMessage>,
    /// When true, the API responds with server-sent events.
    stream: bool,
    /// Tool definitions offered to the model; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<AnthropicTool>>,
    /// Top-level system prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
}
/// A single conversation turn in Anthropic's wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct AnthropicMessage {
    /// Either "user" or "assistant" (the only roles this module produces).
    role: String,
    /// Plain-text message content.
    content: String,
}
/// A tool definition in Anthropic's format, converted from the
/// provider-agnostic `DynamicToolDefinition`.
#[derive(Debug, Serialize, Clone)]
struct AnthropicTool {
    name: String,
    description: String,
    /// JSON Schema describing the tool's input parameters.
    input_schema: Value,
}
/// Non-streaming Messages API response body.
///
/// NOTE(review): not referenced anywhere in this module — presumably retained
/// for a future non-streaming path; confirm before removing.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    /// The response `type` discriminator from the API payload.
    #[serde(rename = "type")]
    response_type: String,
    content: Vec<AnthropicContent>,
}
/// One server-sent event from the streaming Messages API.
///
/// The event `type` (e.g. "content_block_delta", "message_stop") is captured
/// separately; all remaining fields are kept as raw JSON in `data` via
/// `#[serde(flatten)]` so each event kind can be inspected ad hoc.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicStreamResponse {
    #[serde(rename = "type")]
    response_type: String,
    #[serde(flatten)]
    data: Value,
}
/// A content block in a non-streaming response, discriminated by the JSON
/// `type` field.
///
/// NOTE(review): only referenced by `AnthropicResponse`, which is itself
/// unused in this module — confirm before removing either.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicContent {
    /// Plain text produced by the model.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        /// Tool arguments as arbitrary JSON matching the tool's input schema.
        input: Value,
    },
}
/// Client for Anthropic's Claude models.
pub struct AnthropicClient {
    /// API key sent via the `x-api-key` header.
    api_key: String,
    /// Model identifier, e.g. "claude-3-haiku-20240307".
    model: String,
}

impl AnthropicClient {
    /// Creates a new client.
    ///
    /// Generalized to accept anything convertible into a `String`
    /// (`&str`, `String`, `Cow<str>`, ...). This is backward-compatible with
    /// the previous `&str`-only signature and lets callers hand over an
    /// already-owned `String` without an extra allocation.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
        }
    }

    /// Returns the configured model name as an owned `String`.
    pub fn get_model_name(&self) -> String {
        self.model.clone()
    }
}
impl AnthropicClient {
    /// Streams a chat completion from the Anthropic Messages API.
    ///
    /// Converts the provider-agnostic `messages` into Anthropic's format
    /// (system messages are hoisted into the top-level `system` field;
    /// `Role::Tool` results are sent back as `user` messages), issues a
    /// streaming request, and yields `StreamEvent`s parsed from the
    /// server-sent-event response.
    ///
    /// # Errors
    /// Returns `WorkerError` if the HTTP request fails or the API responds
    /// with a non-success status. Parse failures on individual SSE payloads
    /// are logged and skipped rather than surfaced as errors.
    pub async fn chat_stream<'a>(
        &'a self,
        messages: Vec<Message>,
        tools: Option<&[crate::types::DynamicToolDefinition]>,
        llm_debug: Option<crate::types::LlmDebug>,
    ) -> Result<
        Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
        WorkerError,
    > {
        let client = Client::new();
        let url = UrlConfig::get_completion_url("anthropic");
        // Separate system messages from other messages: Anthropic takes the
        // system prompt as a dedicated top-level field, not as a message turn.
        let mut system_message: Option<String> = None;
        let mut anthropic_messages: Vec<AnthropicMessage> = Vec::new();
        for msg in messages {
            match msg.role {
                Role::System => {
                    // Combine multiple system messages if they exist,
                    // joined with a newline.
                    if let Some(existing) = system_message {
                        system_message = Some(format!(
                            "{}
{}",
                            existing, msg.content
                        ));
                    } else {
                        system_message = Some(msg.content);
                    }
                }
                Role::User => {
                    anthropic_messages.push(AnthropicMessage {
                        role: "user".to_string(),
                        content: msg.content,
                    });
                }
                Role::Model => {
                    anthropic_messages.push(AnthropicMessage {
                        role: "assistant".to_string(),
                        content: msg.content,
                    });
                }
                Role::Tool => {
                    // NOTE(review): tool results are flattened into plain
                    // `user` text here; Anthropic also supports structured
                    // `tool_result` content blocks — confirm this is intended.
                    anthropic_messages.push(AnthropicMessage {
                        role: "user".to_string(),
                        content: msg.content,
                    });
                }
            }
        }
        // Convert tools to Anthropic format.
        let anthropic_tools = tools.map(|tools| {
            tools
                .iter()
                .map(|tool| AnthropicTool {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    input_schema: tool.parameters_schema.clone(),
                })
                .collect()
        });
        let request = AnthropicRequest {
            model: self.model.clone(),
            max_tokens: 4096,
            messages: anthropic_messages,
            stream: true,
            tools: anthropic_tools,
            system: system_message,
        };
        // Log request details for debugging.
        tracing::debug!(
            "Anthropic API request: {}",
            serde_json::to_string_pretty(&request).unwrap_or_default()
        );
        let response = client
            .post(url)
            .header("Content-Type", "application/json")
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
            })?;
        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            return Err(WorkerError::from_api_error(
                format!("Anthropic API error: {} - {}", status, error_body),
                &crate::types::LlmProvider::Claude,
            ));
        }
        let stream = stream! {
            // Emit the serialized request as a debug event, if debugging is enabled.
            if let Some(ref debug) = llm_debug {
                if let Some(debug_event) = debug.debug_request(&self.model, "Anthropic", &serde_json::to_value(&request).unwrap_or_default()) {
                    yield Ok(debug_event);
                }
            }
            let mut stream = response.bytes_stream();
            // Accumulates partial SSE lines across network chunks.
            let mut buffer = String::new();
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        let chunk_str = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&chunk_str);
                        // Process complete server-sent-event lines from the buffer.
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].to_string();
                            buffer = buffer[line_end + 1..].to_string();
                            if line.starts_with("data: ") {
                                let data = &line[6..];
                                if data == "[DONE]" {
                                    // NOTE(review): "[DONE]" is an OpenAI-style sentinel;
                                    // Anthropic's stream ends with a `message_stop` event
                                    // instead. Kept as a defensive check — confirm it can go.
                                    break;
                                }
                                match serde_json::from_str::<AnthropicStreamResponse>(data) {
                                    Ok(stream_response) => {
                                        // Forward the raw event as a debug event, if enabled.
                                        if let Some(ref debug) = llm_debug {
                                            if let Some(debug_event) = debug.debug_response(&self.model, "Anthropic", &serde_json::to_value(&stream_response).unwrap_or_default()) {
                                                yield Ok(debug_event);
                                            }
                                        }
                                        match stream_response.response_type.as_str() {
                                            "content_block_delta" => {
                                                // Incremental text tokens.
                                                if let Some(delta) = stream_response.data.get("delta") {
                                                    if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
                                                        yield Ok(StreamEvent::Chunk(text.to_string()));
                                                    }
                                                }
                                            }
                                            "content_block_start" => {
                                                // A new content block; emit a ToolCall for `tool_use` blocks.
                                                // NOTE(review): tool input is taken from the initial block only;
                                                // any `input_json_delta` fragments streamed afterwards are not
                                                // accumulated — confirm arguments are complete at this point.
                                                if let Some(content_block) = stream_response.data.get("content_block") {
                                                    if let Some(block_type) = content_block.get("type").and_then(|t| t.as_str()) {
                                                        if block_type == "tool_use" {
                                                            if let (Some(name), Some(input)) = (
                                                                content_block.get("name").and_then(|n| n.as_str()),
                                                                content_block.get("input")
                                                            ) {
                                                                let tool_call = ToolCall {
                                                                    name: name.to_string(),
                                                                    arguments: input.to_string(),
                                                                };
                                                                yield Ok(StreamEvent::ToolCall(tool_call));
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                            "message_start" => {
                                                tracing::debug!("Anthropic message stream started");
                                            }
                                            "message_delta" => {
                                                if let Some(delta) = stream_response.data.get("delta") {
                                                    if let Some(stop_reason) = delta.get("stop_reason") {
                                                        tracing::debug!("Anthropic message stop reason: {}", stop_reason);
                                                    }
                                                }
                                            }
                                            "message_stop" => {
                                                // End of the message: emit an empty Completion event.
                                                // NOTE(review): this `break` only exits the inner
                                                // line-splitting loop; the outer chunk loop keeps
                                                // polling until the network stream ends — confirm
                                                // that is the intended shutdown behavior.
                                                tracing::debug!("Anthropic message stream stopped");
                                                yield Ok(StreamEvent::Completion(Message::new(
                                                    Role::Model,
                                                    "".to_string(),
                                                )));
                                                break;
                                            }
                                            "content_block_stop" => {
                                                tracing::debug!("Anthropic content block stopped");
                                            }
                                            "ping" => {
                                                // Keep-alive event; nothing to do.
                                                tracing::debug!("Anthropic ping received");
                                            }
                                            "error" => {
                                                // Mid-stream error reported by the API.
                                                if let Some(error) = stream_response.data.get("error") {
                                                    let error_msg = error.get("message")
                                                        .and_then(|m| m.as_str())
                                                        .unwrap_or("Unknown error");
                                                    tracing::error!("Anthropic stream error: {}", error_msg);
                                                    yield Err(WorkerError::from_api_error(
                                                        format!("Anthropic stream error: {}", error_msg),
                                                        &crate::types::LlmProvider::Claude,
                                                    ));
                                                }
                                            }
                                            _ => {
                                                tracing::debug!("Unhandled Anthropic stream event: {}", stream_response.response_type);
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        // Malformed SSE payloads are logged and skipped,
                                        // not surfaced as stream errors.
                                        tracing::warn!("Failed to parse Anthropic stream response: {} - Raw data: {}", e, data);
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        // Network-level failure: surface it and stop streaming.
                        yield Err(WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude));
                        break;
                    }
                }
            }
        };
        Ok(Box::new(Box::pin(stream)))
    }

    /// Verifies that the configured API key and model can reach the Anthropic
    /// API by sending a minimal (1 output token) non-streaming request.
    ///
    /// Falls back to a known-valid default model when `self.model` is empty,
    /// so connectivity can be checked before a model is configured.
    ///
    /// # Errors
    /// Returns `WorkerError` on network failure or any non-success HTTP status.
    pub async fn check_connection(&self) -> Result<(), WorkerError> {
        let client = Client::new();
        let url = UrlConfig::get_completion_url("anthropic");
        // Use a default valid model for connection testing if model is empty.
        let test_model = if self.model.is_empty() {
            "claude-3-haiku-20240307".to_string()
        } else {
            self.model.clone()
        };
        tracing::debug!(
            "Anthropic connection test: Using model '{}' with API key length: {}",
            test_model,
            self.api_key.len()
        );
        let test_request = AnthropicRequest {
            model: test_model,
            max_tokens: 1,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: "Hi".to_string(),
            }],
            stream: false,
            tools: None,
            system: None,
        };
        let response = client
            .post(url)
            .header("Content-Type", "application/json")
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .json(&test_request)
            .send()
            .await
            .map_err(|e| {
                tracing::error!("Anthropic connection test network error: {}", e);
                WorkerError::from_api_error(e.to_string(), &crate::types::LlmProvider::Claude)
            })?;
        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_default();
            tracing::error!(
                "Anthropic connection test failed: Status={}, Body={}",
                status,
                error_body
            );
            return Err(WorkerError::from_api_error(
                format!(
                    "Anthropic connection test failed: {} - {}",
                    status, error_body
                ),
                &crate::types::LlmProvider::Claude,
            ));
        }
        Ok(())
    }
}
#[async_trait::async_trait]
impl LlmClientTrait for AnthropicClient {
async fn chat_stream<'a>(
&'a self,
messages: Vec<Message>,
tools: Option<&[crate::types::DynamicToolDefinition]>,
llm_debug: Option<crate::types::LlmDebug>,
) -> Result<
Box<dyn Stream<Item = Result<StreamEvent, WorkerError>> + Unpin + Send + 'a>,
WorkerError,
> {
self.chat_stream(messages, tools, llm_debug).await
}
async fn check_connection(&self) -> Result<(), WorkerError> {
self.check_connection().await
}
fn provider(&self) -> LlmProvider {
LlmProvider::Claude
}
fn get_model_name(&self) -> String {
self.get_model_name()
}
}