feat(core): implement token counting and context window management
Add token.rs with:
- TokenCounter for estimating token usage
- ModelLimits for Claude model configurations
- ContextManager for automatic context pruning
- TokenUsage and ContextUsage statistics
- Support for 200k context window
- Automatic oldest message pruning
- 10 unit tests

Closes #21

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
04f7bc897e
commit
277d0e4858
2 changed files with 440 additions and 0 deletions
|
|
@ -8,6 +8,7 @@ pub mod anthropic;
|
|||
pub mod tool;
|
||||
pub mod conversation;
|
||||
pub mod tools;
|
||||
pub mod token;
|
||||
|
||||
pub use error::Error;
|
||||
pub use types::*;
|
||||
|
|
@ -23,3 +24,4 @@ pub use conversation::{
|
|||
Conversation, ConversationMessage, ConversationManager, ConversationMetadata, ConversationError,
|
||||
};
|
||||
pub use tools::{ReadTool, WriteTool, EditTool, create_file_tool_registry};
|
||||
pub use token::{TokenCounter, TokenUsage, ContextManager, ContextUsage, ModelLimits};
|
||||
|
|
|
|||
438
crates/miyabi-core/src/token.rs
Normal file
438
crates/miyabi-core/src/token.rs
Normal file
|
|
@ -0,0 +1,438 @@
|
|||
//! Token Counting and Context Window Management
|
||||
//!
|
||||
//! This module provides token estimation and context window management
|
||||
//! for Claude API conversations.
|
||||
|
||||
use crate::anthropic::{ContentBlock, Message};
|
||||
use crate::conversation::{Conversation, ConversationMessage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Token limits of a Claude model.
///
/// Both fields are expressed in tokens.
#[derive(Debug, Clone, Copy)]
pub struct ModelLimits {
    /// Size of the context window.
    pub context_window: usize,
    /// Maximum number of output tokens.
    pub max_output: usize,
}

impl ModelLimits {
    /// Limits for Claude 3 Opus.
    pub const CLAUDE_3_OPUS: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Limits for Claude 3 Sonnet.
    pub const CLAUDE_3_SONNET: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Limits for Claude 3 Haiku.
    pub const CLAUDE_3_HAIKU: Self = Self {
        context_window: 200_000,
        max_output: 4_096,
    };

    /// Looks up the limits for a model name.
    ///
    /// The lookup is a case-sensitive substring test: a name containing
    /// `"opus"` yields [`Self::CLAUDE_3_OPUS`], otherwise a name containing
    /// `"haiku"` yields [`Self::CLAUDE_3_HAIKU`], and every other name falls
    /// back to [`Self::CLAUDE_3_SONNET`].
    pub fn for_model(model: &str) -> Self {
        match model {
            name if name.contains("opus") => Self::CLAUDE_3_OPUS,
            name if name.contains("haiku") => Self::CLAUDE_3_HAIKU,
            _ => Self::CLAUDE_3_SONNET,
        }
    }
}
|
||||
|
||||
/// Token usage statistics
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct TokenUsage {
|
||||
/// Estimated input tokens
|
||||
pub input_tokens: usize,
|
||||
/// Output tokens (from API response)
|
||||
pub output_tokens: usize,
|
||||
/// Total tokens used
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
/// Create new usage
|
||||
pub fn new(input: usize, output: usize) -> Self {
|
||||
Self {
|
||||
input_tokens: input,
|
||||
output_tokens: output,
|
||||
total_tokens: input + output,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add usage
|
||||
pub fn add(&mut self, other: &TokenUsage) {
|
||||
self.input_tokens += other.input_tokens;
|
||||
self.output_tokens += other.output_tokens;
|
||||
self.total_tokens += other.total_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
/// Token counter for estimating token usage
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TokenCounter {
|
||||
/// Model limits
|
||||
pub limits: ModelLimits,
|
||||
/// Characters per token estimate (Claude uses ~4 chars/token on average)
|
||||
pub chars_per_token: f32,
|
||||
}
|
||||
|
||||
impl Default for TokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter {
|
||||
/// Create a new token counter with default settings
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
limits: ModelLimits::CLAUDE_3_SONNET,
|
||||
chars_per_token: 4.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific model limits
|
||||
pub fn with_model(model: &str) -> Self {
|
||||
Self {
|
||||
limits: ModelLimits::for_model(model),
|
||||
chars_per_token: 4.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate tokens for text
|
||||
pub fn estimate_text(&self, text: &str) -> usize {
|
||||
(text.len() as f32 / self.chars_per_token).ceil() as usize
|
||||
}
|
||||
|
||||
/// Estimate tokens for a content block
|
||||
pub fn estimate_content_block(&self, block: &ContentBlock) -> usize {
|
||||
match block {
|
||||
ContentBlock::Text { text } => self.estimate_text(text),
|
||||
ContentBlock::ToolUse { name, input, .. } => {
|
||||
let input_str = serde_json::to_string(input).unwrap_or_default();
|
||||
self.estimate_text(name) + self.estimate_text(&input_str) + 20
|
||||
}
|
||||
ContentBlock::ToolResult { content, .. } => {
|
||||
self.estimate_text(content) + 20
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate tokens for a message
|
||||
pub fn estimate_message(&self, message: &Message) -> usize {
|
||||
let content_tokens: usize = message
|
||||
.content
|
||||
.iter()
|
||||
.map(|b| self.estimate_content_block(b))
|
||||
.sum();
|
||||
|
||||
// Add overhead for message structure
|
||||
content_tokens + 10
|
||||
}
|
||||
|
||||
/// Estimate tokens for a conversation message
|
||||
pub fn estimate_conversation_message(&self, message: &ConversationMessage) -> usize {
|
||||
let content_tokens: usize = message
|
||||
.content
|
||||
.iter()
|
||||
.map(|b| self.estimate_content_block(b))
|
||||
.sum();
|
||||
|
||||
content_tokens + 10
|
||||
}
|
||||
|
||||
/// Estimate tokens for a conversation
|
||||
pub fn estimate_conversation(&self, conversation: &Conversation) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// System prompt
|
||||
if let Some(ref prompt) = conversation.system_prompt {
|
||||
total += self.estimate_text(prompt) + 10;
|
||||
}
|
||||
|
||||
// Messages
|
||||
for message in &conversation.messages {
|
||||
total += self.estimate_conversation_message(message);
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
/// Get available tokens for output
|
||||
pub fn available_output_tokens(&self, conversation: &Conversation) -> usize {
|
||||
let input = self.estimate_conversation(conversation);
|
||||
let available = self.limits.context_window.saturating_sub(input);
|
||||
available.min(self.limits.max_output)
|
||||
}
|
||||
|
||||
/// Check if conversation is within limits
|
||||
pub fn within_limits(&self, conversation: &Conversation) -> bool {
|
||||
let tokens = self.estimate_conversation(conversation);
|
||||
tokens < self.limits.context_window - self.limits.max_output
|
||||
}
|
||||
|
||||
/// Get context window size
|
||||
pub fn context_window(&self) -> usize {
|
||||
self.limits.context_window
|
||||
}
|
||||
|
||||
/// Get max output tokens
|
||||
pub fn max_output(&self) -> usize {
|
||||
self.limits.max_output
|
||||
}
|
||||
}
|
||||
|
||||
/// Context window manager for automatic pruning
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextManager {
|
||||
/// Token counter
|
||||
pub counter: TokenCounter,
|
||||
/// Target utilization (0.0 - 1.0)
|
||||
pub target_utilization: f32,
|
||||
/// Minimum messages to keep
|
||||
pub min_messages: usize,
|
||||
}
|
||||
|
||||
impl Default for ContextManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ContextManager {
|
||||
/// Create a new context manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: TokenCounter::new(),
|
||||
target_utilization: 0.8, // Keep 80% of context for input
|
||||
min_messages: 2, // Keep at least 2 messages
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific model
|
||||
pub fn with_model(model: &str) -> Self {
|
||||
Self {
|
||||
counter: TokenCounter::with_model(model),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set target utilization
|
||||
pub fn with_target_utilization(mut self, utilization: f32) -> Self {
|
||||
self.target_utilization = utilization.clamp(0.1, 0.95);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set minimum messages to keep
|
||||
pub fn with_min_messages(mut self, min: usize) -> Self {
|
||||
self.min_messages = min.max(1);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current token estimate
|
||||
pub fn estimate_tokens(&self, conversation: &Conversation) -> usize {
|
||||
self.counter.estimate_conversation(conversation)
|
||||
}
|
||||
|
||||
/// Get target token limit
|
||||
pub fn target_limit(&self) -> usize {
|
||||
(self.counter.context_window() as f32 * self.target_utilization) as usize
|
||||
}
|
||||
|
||||
/// Check if pruning is needed
|
||||
pub fn needs_pruning(&self, conversation: &Conversation) -> bool {
|
||||
self.estimate_tokens(conversation) > self.target_limit()
|
||||
}
|
||||
|
||||
/// Prune conversation to fit within limits
|
||||
///
|
||||
/// Removes oldest messages while keeping system prompt and minimum messages
|
||||
pub fn prune(&self, conversation: &mut Conversation) -> usize {
|
||||
let target = self.target_limit();
|
||||
let mut removed = 0;
|
||||
|
||||
while self.estimate_tokens(conversation) > target
|
||||
&& conversation.message_count() > self.min_messages
|
||||
{
|
||||
conversation.messages.remove(0);
|
||||
removed += 1;
|
||||
}
|
||||
|
||||
removed
|
||||
}
|
||||
|
||||
/// Get usage statistics
|
||||
pub fn get_usage(&self, conversation: &Conversation) -> ContextUsage {
|
||||
let tokens = self.estimate_tokens(conversation);
|
||||
let limit = self.counter.context_window();
|
||||
let available = self.counter.available_output_tokens(conversation);
|
||||
|
||||
ContextUsage {
|
||||
used_tokens: tokens,
|
||||
total_tokens: limit,
|
||||
available_output: available,
|
||||
utilization: tokens as f32 / limit as f32,
|
||||
message_count: conversation.message_count(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context usage statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ContextUsage {
|
||||
/// Tokens used by current context
|
||||
pub used_tokens: usize,
|
||||
/// Total context window size
|
||||
pub total_tokens: usize,
|
||||
/// Available tokens for output
|
||||
pub available_output: usize,
|
||||
/// Utilization percentage (0.0 - 1.0)
|
||||
pub utilization: f32,
|
||||
/// Number of messages
|
||||
pub message_count: usize,
|
||||
}
|
||||
|
||||
impl ContextUsage {
|
||||
/// Format as display string
|
||||
pub fn display(&self) -> String {
|
||||
format!(
|
||||
"{}/{} tokens ({:.1}%) | {} messages",
|
||||
self.used_tokens,
|
||||
self.total_tokens,
|
||||
self.utilization * 100.0,
|
||||
self.message_count
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The text estimate lands close to length / 4.
    #[test]
    fn test_token_estimation() {
        let counter = TokenCounter::new();

        // ~4 chars per token
        let tokens = counter.estimate_text("Hello, World!"); // 13 chars
        assert!((3..=5).contains(&tokens));
    }

    /// Model lookup returns the 200k context window.
    #[test]
    fn test_model_limits() {
        let limits = ModelLimits::for_model("claude-3-opus");
        assert_eq!(limits.context_window, 200_000);

        let limits = ModelLimits::for_model("claude-3-sonnet");
        assert_eq!(limits.context_window, 200_000);
    }

    /// A conversation with a prompt and two messages has a non-zero estimate.
    #[test]
    fn test_conversation_estimation() {
        let counter = TokenCounter::new();
        let mut conv = Conversation::new()
            .with_system_prompt("You are a helpful assistant");

        conv.add_user_message("Hello");
        conv.add_assistant_message("Hi there!");

        let tokens = counter.estimate_conversation(&conv);
        assert!(tokens > 0);
    }

    /// A tiny conversation is within the default limits.
    #[test]
    fn test_within_limits() {
        let counter = TokenCounter::new();
        let conv = Conversation::new()
            .with_system_prompt("Test");

        assert!(counter.within_limits(&conv));
    }

    /// The default manager targets 80% utilization.
    #[test]
    fn test_context_manager_creation() {
        let manager = ContextManager::new();
        assert_eq!(manager.target_utilization, 0.8);
    }

    /// Pruning removes messages once the estimate exceeds the target.
    #[test]
    fn test_context_manager_pruning() {
        // Create manager with small model limits for testing
        let mut manager = ContextManager::new();
        manager.counter.limits = ModelLimits {
            context_window: 100, // Very small for testing
            max_output: 10,
        };
        manager.target_utilization = 0.5; // 50 token target

        let mut conv = Conversation::new();
        // Ten messages exceed the 50-token target: the per-message overhead
        // alone is 10 tokens each.
        for i in 0..10 {
            conv.add_user_message(format!("This is message number {} with substantial content that uses many tokens", i));
        }

        let initial = conv.message_count();
        let removed = manager.prune(&mut conv);

        assert!(removed > 0, "Expected messages to be removed, tokens: {}", manager.estimate_tokens(&conv));
        assert!(conv.message_count() < initial);
    }

    /// The display string contains the raw token figures.
    #[test]
    fn test_context_usage_display() {
        let usage = ContextUsage {
            used_tokens: 1000,
            total_tokens: 200_000,
            available_output: 4096,
            utilization: 0.005,
            message_count: 5,
        };

        let display = usage.display();
        assert!(display.contains("1000"));
        assert!(display.contains("200000"));
    }

    /// An empty conversation leaves the full output budget available.
    #[test]
    fn test_available_output_tokens() {
        let counter = TokenCounter::new();
        let conv = Conversation::new();

        let available = counter.available_output_tokens(&conv);
        assert_eq!(available, counter.max_output());
    }

    /// `TokenUsage::add` accumulates totals.
    #[test]
    fn test_token_usage() {
        let mut usage = TokenUsage::new(100, 50);
        assert_eq!(usage.total_tokens, 150);

        usage.add(&TokenUsage::new(50, 25));
        assert_eq!(usage.total_tokens, 225);
    }

    /// Pruning stops at `min_messages` even when the target is still exceeded.
    #[test]
    fn test_min_messages_preserved() {
        // `with_target_utilization` clamps to at least 0.1, so with the
        // default 200k window the target would be 20_000 tokens and ten short
        // messages would never trigger pruning. Shrink the window so the
        // target (10 tokens) is exceeded even by the five messages that must
        // be kept.
        let mut manager = ContextManager::new()
            .with_target_utilization(0.1)
            .with_min_messages(5);
        manager.counter.limits = ModelLimits {
            context_window: 100,
            max_output: 10,
        };

        let mut conv = Conversation::new();
        for i in 0..10 {
            conv.add_user_message(format!("Message {}", i));
        }
        assert!(manager.needs_pruning(&conv), "test setup must exceed the target");

        let removed = manager.prune(&mut conv);

        assert_eq!(removed, 5);
        assert_eq!(conv.message_count(), 5);
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue