462 lines
13 KiB
Rust
462 lines
13 KiB
Rust
//! Token Counting and Context Window Management
|
|
//!
|
|
//! This module provides token estimation and context window management
|
|
//! for Claude API conversations.
|
|
|
|
use crate::anthropic::{ContentBlock, Message, DEFAULT_MODEL};
|
|
use crate::conversation::{Conversation, ConversationMessage};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// Context-window and output-token limits for a Claude model, in tokens.
///
/// Presets are exposed as associated constants; [`ModelLimits::for_model`] picks one
/// from a model identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelLimits {
    /// Maximum context window size (prompt and response together), in tokens.
    pub context_window: usize,
    /// Maximum number of output tokens for a single response.
    pub max_output: usize,
}

impl ModelLimits {
    /// Claude Sonnet 4.5 limits.
    ///
    /// NOTE(review): Anthropic's model overview documents a much larger maximum output
    /// for the 4.5 models; 4_096 may be an intentional per-request budget rather than the
    /// model ceiling — confirm before relying on it.
    pub const CLAUDE_4_5_SONNET: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Claude Haiku 4.5 limits (see the NOTE on [`ModelLimits::CLAUDE_4_5_SONNET`]).
    pub const CLAUDE_4_5_HAIKU: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Claude 3 Opus limits; `for_model` also uses these for any other "opus" name.
    pub const CLAUDE_3_OPUS: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Claude 3 Sonnet limits; `for_model` also uses these as the fallback for any
    /// identifier that matches no other pattern.
    pub const CLAUDE_3_SONNET: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Claude 3 Haiku limits; `for_model` also uses these for any "haiku" name other
    /// than Haiku 4.5.
    pub const CLAUDE_3_HAIKU: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Resolve limits for a model identifier such as `"claude-sonnet-4-5-20250929"`.
    ///
    /// Matching is a case-insensitive substring search, tried in this order:
    /// `sonnet-4-5`, `haiku-4-5`, `opus`, `haiku`. Anything else (including other
    /// Sonnet versions and unknown names) falls back to [`ModelLimits::CLAUDE_3_SONNET`].
    pub fn for_model(model: &str) -> Self {
        let name = model.to_lowercase();

        // Ordered: the specific 4.5 patterns must be tried before the bare family names,
        // otherwise "haiku-4-5" would be captured by "haiku".
        let patterns: [(&str, Self); 4] = [
            ("sonnet-4-5", Self::CLAUDE_4_5_SONNET),
            ("haiku-4-5", Self::CLAUDE_4_5_HAIKU),
            ("opus", Self::CLAUDE_3_OPUS),
            ("haiku", Self::CLAUDE_3_HAIKU),
        ];

        patterns
            .iter()
            .find(|(needle, _)| name.contains(*needle))
            .map(|&(_, limits)| limits)
            .unwrap_or(Self::CLAUDE_3_SONNET)
    }
}
|
|
|
|
/// Token usage statistics
|
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
|
pub struct TokenUsage {
|
|
/// Estimated input tokens
|
|
pub input_tokens: usize,
|
|
/// Output tokens (from API response)
|
|
pub output_tokens: usize,
|
|
/// Total tokens used
|
|
pub total_tokens: usize,
|
|
}
|
|
|
|
impl TokenUsage {
|
|
/// Create new usage
|
|
pub fn new(input: usize, output: usize) -> Self {
|
|
Self {
|
|
input_tokens: input,
|
|
output_tokens: output,
|
|
total_tokens: input + output,
|
|
}
|
|
}
|
|
|
|
/// Add usage
|
|
pub fn add(&mut self, other: &TokenUsage) {
|
|
self.input_tokens += other.input_tokens;
|
|
self.output_tokens += other.output_tokens;
|
|
self.total_tokens += other.total_tokens;
|
|
}
|
|
}
|
|
|
|
/// Token counter for estimating token usage
|
|
#[derive(Debug, Clone)]
|
|
pub struct TokenCounter {
|
|
/// Model limits
|
|
pub limits: ModelLimits,
|
|
/// Characters per token estimate (Claude uses ~4 chars/token on average)
|
|
pub chars_per_token: f32,
|
|
}
|
|
|
|
impl Default for TokenCounter {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl TokenCounter {
|
|
/// Create a new token counter with default settings
|
|
pub fn new() -> Self {
|
|
Self::with_model(DEFAULT_MODEL)
|
|
}
|
|
|
|
/// Create with specific model limits
|
|
pub fn with_model(model: &str) -> Self {
|
|
Self {
|
|
limits: ModelLimits::for_model(model),
|
|
chars_per_token: 4.0,
|
|
}
|
|
}
|
|
|
|
/// Estimate tokens for text
|
|
pub fn estimate_text(&self, text: &str) -> usize {
|
|
(text.len() as f32 / self.chars_per_token).ceil() as usize
|
|
}
|
|
|
|
/// Estimate tokens for a content block
|
|
pub fn estimate_content_block(&self, block: &ContentBlock) -> usize {
|
|
match block {
|
|
ContentBlock::Text { text } => self.estimate_text(text),
|
|
ContentBlock::ToolUse { name, input, .. } => {
|
|
let input_str = serde_json::to_string(input).unwrap_or_default();
|
|
self.estimate_text(name) + self.estimate_text(&input_str) + 20
|
|
}
|
|
ContentBlock::ToolResult { content, .. } => self.estimate_text(content) + 20,
|
|
}
|
|
}
|
|
|
|
/// Estimate tokens for a message
|
|
pub fn estimate_message(&self, message: &Message) -> usize {
|
|
let content_tokens: usize = message
|
|
.content
|
|
.iter()
|
|
.map(|b| self.estimate_content_block(b))
|
|
.sum();
|
|
|
|
// Add overhead for message structure
|
|
content_tokens + 10
|
|
}
|
|
|
|
/// Estimate tokens for a conversation message
|
|
pub fn estimate_conversation_message(&self, message: &ConversationMessage) -> usize {
|
|
let content_tokens: usize = message
|
|
.content
|
|
.iter()
|
|
.map(|b| self.estimate_content_block(b))
|
|
.sum();
|
|
|
|
content_tokens + 10
|
|
}
|
|
|
|
/// Estimate tokens for a conversation
|
|
pub fn estimate_conversation(&self, conversation: &Conversation) -> usize {
|
|
let mut total = 0;
|
|
|
|
// System prompt
|
|
if let Some(ref prompt) = conversation.system_prompt {
|
|
total += self.estimate_text(prompt) + 10;
|
|
}
|
|
|
|
// Messages
|
|
for message in &conversation.messages {
|
|
total += self.estimate_conversation_message(message);
|
|
}
|
|
|
|
total
|
|
}
|
|
|
|
/// Get available tokens for output
|
|
pub fn available_output_tokens(&self, conversation: &Conversation) -> usize {
|
|
let input = self.estimate_conversation(conversation);
|
|
let available = self.limits.context_window.saturating_sub(input);
|
|
available.min(self.limits.max_output)
|
|
}
|
|
|
|
/// Check if conversation is within limits
|
|
pub fn within_limits(&self, conversation: &Conversation) -> bool {
|
|
let tokens = self.estimate_conversation(conversation);
|
|
tokens < self.limits.context_window - self.limits.max_output
|
|
}
|
|
|
|
/// Get context window size
|
|
pub fn context_window(&self) -> usize {
|
|
self.limits.context_window
|
|
}
|
|
|
|
/// Get max output tokens
|
|
pub fn max_output(&self) -> usize {
|
|
self.limits.max_output
|
|
}
|
|
}
|
|
|
|
/// Context window manager for automatic pruning
|
|
#[derive(Debug, Clone)]
|
|
pub struct ContextManager {
|
|
/// Token counter
|
|
pub counter: TokenCounter,
|
|
/// Target utilization (0.0 - 1.0)
|
|
pub target_utilization: f32,
|
|
/// Minimum messages to keep
|
|
pub min_messages: usize,
|
|
}
|
|
|
|
impl Default for ContextManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl ContextManager {
|
|
/// Create a new context manager
|
|
pub fn new() -> Self {
|
|
Self {
|
|
counter: TokenCounter::new(),
|
|
target_utilization: 0.8, // Keep 80% of context for input
|
|
min_messages: 2, // Keep at least 2 messages
|
|
}
|
|
}
|
|
|
|
/// Create with specific model
|
|
pub fn with_model(model: &str) -> Self {
|
|
Self {
|
|
counter: TokenCounter::with_model(model),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
/// Set target utilization
|
|
pub fn with_target_utilization(mut self, utilization: f32) -> Self {
|
|
self.target_utilization = utilization.clamp(0.1, 0.95);
|
|
self
|
|
}
|
|
|
|
/// Set minimum messages to keep
|
|
pub fn with_min_messages(mut self, min: usize) -> Self {
|
|
self.min_messages = min.max(1);
|
|
self
|
|
}
|
|
|
|
/// Get current token estimate
|
|
pub fn estimate_tokens(&self, conversation: &Conversation) -> usize {
|
|
self.counter.estimate_conversation(conversation)
|
|
}
|
|
|
|
/// Get target token limit
|
|
pub fn target_limit(&self) -> usize {
|
|
(self.counter.context_window() as f32 * self.target_utilization) as usize
|
|
}
|
|
|
|
/// Check if pruning is needed
|
|
pub fn needs_pruning(&self, conversation: &Conversation) -> bool {
|
|
self.estimate_tokens(conversation) > self.target_limit()
|
|
}
|
|
|
|
/// Prune conversation to fit within limits
|
|
///
|
|
/// Removes oldest messages while keeping system prompt and minimum messages
|
|
pub fn prune(&self, conversation: &mut Conversation) -> usize {
|
|
let target = self.target_limit();
|
|
let mut removed = 0;
|
|
|
|
while self.estimate_tokens(conversation) > target
|
|
&& conversation.message_count() > self.min_messages
|
|
{
|
|
conversation.messages.remove(0);
|
|
removed += 1;
|
|
}
|
|
|
|
removed
|
|
}
|
|
|
|
/// Get usage statistics
|
|
pub fn get_usage(&self, conversation: &Conversation) -> ContextUsage {
|
|
let tokens = self.estimate_tokens(conversation);
|
|
let limit = self.counter.context_window();
|
|
let available = self.counter.available_output_tokens(conversation);
|
|
|
|
ContextUsage {
|
|
used_tokens: tokens,
|
|
total_tokens: limit,
|
|
available_output: available,
|
|
utilization: tokens as f32 / limit as f32,
|
|
message_count: conversation.message_count(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Context usage statistics
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ContextUsage {
|
|
/// Tokens used by current context
|
|
pub used_tokens: usize,
|
|
/// Total context window size
|
|
pub total_tokens: usize,
|
|
/// Available tokens for output
|
|
pub available_output: usize,
|
|
/// Utilization percentage (0.0 - 1.0)
|
|
pub utilization: f32,
|
|
/// Number of messages
|
|
pub message_count: usize,
|
|
}
|
|
|
|
impl ContextUsage {
|
|
/// Format as display string
|
|
pub fn display(&self) -> String {
|
|
format!(
|
|
"{}/{} tokens ({:.1}%) | {} messages",
|
|
self.used_tokens,
|
|
self.total_tokens,
|
|
self.utilization * 100.0,
|
|
self.message_count
|
|
)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_estimation() {
        let counter = TokenCounter::new();

        // ~4 chars per token
        let tokens = counter.estimate_text("Hello, World!"); // 13 chars
        assert!((3..=5).contains(&tokens));
    }

    #[test]
    fn test_estimate_text_rounds_up() {
        let counter = TokenCounter::new();

        // ceil(bytes / 4.0) with the default ratio.
        assert_eq!(counter.estimate_text(""), 0);
        assert_eq!(counter.estimate_text("abcd"), 1);
        assert_eq!(counter.estimate_text("abcde"), 2);
    }

    #[test]
    fn test_model_limits() {
        let limits = ModelLimits::for_model("claude-3-opus");
        assert_eq!(limits.context_window, 200_000);

        let limits = ModelLimits::for_model("claude-3-sonnet");
        assert_eq!(limits.context_window, 200_000);

        let limits = ModelLimits::for_model("claude-sonnet-4-5-20250929");
        assert_eq!(limits.context_window, 200_000);

        let limits = ModelLimits::for_model("claude-haiku-4-5-20251001");
        assert_eq!(limits.context_window, 200_000);
    }

    #[test]
    fn test_conversation_estimation() {
        let counter = TokenCounter::new();
        let mut conv = Conversation::new().with_system_prompt("You are a helpful assistant");

        conv.add_user_message("Hello");
        conv.add_assistant_message("Hi there!");

        let tokens = counter.estimate_conversation(&conv);
        assert!(tokens > 0);
    }

    #[test]
    fn test_within_limits() {
        let counter = TokenCounter::new();
        let conv = Conversation::new().with_system_prompt("Test");

        assert!(counter.within_limits(&conv));
    }

    #[test]
    fn test_context_manager_creation() {
        let manager = ContextManager::new();
        assert_eq!(manager.target_utilization, 0.8);
    }

    #[test]
    fn test_target_utilization_is_clamped() {
        let high = ContextManager::new().with_target_utilization(5.0);
        assert_eq!(high.target_utilization, 0.95);

        let low = ContextManager::new().with_target_utilization(0.0);
        assert_eq!(low.target_utilization, 0.1);
    }

    #[test]
    fn test_context_manager_pruning() {
        // Create manager with small model limits for testing
        let mut manager = ContextManager::new();
        manager.counter.limits = ModelLimits {
            context_window: 100, // Very small for testing
            max_output: 10,
        };
        manager.target_utilization = 0.5; // 50 token target

        let mut conv = Conversation::new();
        // Add messages that will exceed the limit
        for i in 0..10 {
            // Each message is ~71 bytes -> ~18 text tokens + 10 overhead ≈ 28 tokens
            conv.add_user_message(format!(
                "This is message number {} with substantial content that uses many tokens",
                i
            ));
        }

        let initial = conv.message_count();
        let removed = manager.prune(&mut conv);

        assert!(
            removed > 0,
            "Expected messages to be removed, tokens: {}",
            manager.estimate_tokens(&conv)
        );
        assert!(conv.message_count() < initial);
        // Pruning stops either because the target is met or because the floor is reached.
        assert!(
            !manager.needs_pruning(&conv) || conv.message_count() <= manager.min_messages
        );
    }

    #[test]
    fn test_prune_is_noop_under_target() {
        let manager = ContextManager::new();
        let mut conv = Conversation::new();
        conv.add_user_message("Hello");
        conv.add_assistant_message("Hi there!");

        assert_eq!(manager.prune(&mut conv), 0);
        assert_eq!(conv.message_count(), 2);
    }

    #[test]
    fn test_context_usage_display() {
        let usage = ContextUsage {
            used_tokens: 1000,
            total_tokens: 200_000,
            available_output: 4096,
            utilization: 0.005,
            message_count: 5,
        };

        let display = usage.display();
        assert!(display.contains("1000"));
        assert!(display.contains("200000"));
    }

    #[test]
    fn test_context_usage_display_format() {
        let usage = ContextUsage {
            used_tokens: 50_000,
            total_tokens: 200_000,
            available_output: 4096,
            utilization: 0.25,
            message_count: 3,
        };

        assert_eq!(usage.display(), "50000/200000 tokens (25.0%) | 3 messages");
    }

    #[test]
    fn test_available_output_tokens() {
        let counter = TokenCounter::new();
        let conv = Conversation::new();

        let available = counter.available_output_tokens(&conv);
        assert_eq!(available, counter.max_output());
    }

    #[test]
    fn test_token_usage() {
        let mut usage = TokenUsage::new(100, 50);
        assert_eq!(usage.total_tokens, 150);

        usage.add(&TokenUsage::new(50, 25));
        assert_eq!(usage.total_tokens, 225);
    }

    #[test]
    fn test_min_messages_preserved() {
        // 0.001 is clamped to 0.1. With the default 200k window that is a 20k-token target,
        // so ten short messages would never be pruned and the test would be vacuous.
        // Shrink the window so the target (10 tokens) is below what even five messages need.
        let mut manager = ContextManager::new()
            .with_target_utilization(0.001)
            .with_min_messages(5);
        manager.counter.limits = ModelLimits {
            context_window: 100,
            max_output: 10,
        };

        let mut conv = Conversation::new();
        for i in 0..10 {
            conv.add_user_message(format!("Message {}", i));
        }

        let removed = manager.prune(&mut conv);
        assert!(removed > 0, "pruning should have run");
        assert_eq!(conv.message_count(), 5);
        // Still over budget: the min_messages floor, not the token target, stopped pruning.
        assert!(manager.needs_pruning(&conv));
    }
}
|