Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for BaseRefinery and OverlapRefinery + minor changes #78

Merged
merged 32 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e938903
[Add] Implement base refinery classes and refined chunk dataclass
bhavnicksm Nov 25, 2024
4f6157a
Remove slots in the Chunk dataclasses
bhavnicksm Nov 29, 2024
c363ab6
Added automated testing using Github Actions
pratyushmittal Nov 25, 2024
d249caa
Run ruff and fix linting errors, add min_chunk_size to chunkers
bhavnicksm Nov 25, 2024
1fd463e
Run ruff and fix linting errors
bhavnicksm Nov 25, 2024
be71355
style: add ruff checks and fix docstrings
bhavnicksm Nov 25, 2024
646db3c
chore: update dev dependencies; use ruff for linting
bhavnicksm Nov 25, 2024
e34f9e0
Chore: Add docstrings for test files in chunker module
bhavnicksm Nov 25, 2024
634e477
Chore: run ruff format
bhavnicksm Nov 25, 2024
45f281d
Skip openai tests if OPENAI_API_KEY is not defined
pratyushmittal Nov 25, 2024
ebfc10a
Chore: Update Ruff rules in pyproject.toml
bhavnicksm Nov 25, 2024
5430d8e
Fix: Change from tokenizer -> tokenizer_or_token_counter
bhavnicksm Nov 25, 2024
bd249e5
Chore: Update GitHub Actions workflow for Python testing
bhavnicksm Nov 25, 2024
bd7d911
[Fix] Allow for functions as token counters in BaseChunker
bhavnicksm Nov 25, 2024
376ba61
[Chore] remove unused code in BaseChunker
bhavnicksm Nov 25, 2024
3f6961d
Add TEVL to speed up sentence chunker
bhavnicksm Nov 26, 2024
fae0026
[chore] run ruff linting
bhavnicksm Nov 26, 2024
784ce19
Remove slots in Chunk dataclasses
bhavnicksm Nov 29, 2024
285e765
Update the chunk dataclass with copy, repr, and len methods
bhavnicksm Nov 29, 2024
fc861e4
Update the Chunk dataclasses to not use slots
bhavnicksm Nov 29, 2024
fb37573
Add Context and Refinery Classes to Chonkie
bhavnicksm Dec 2, 2024
719e33b
Enhance SemanticChunker with error handling and similarity threshold …
bhavnicksm Dec 4, 2024
aa1fe0a
Merge branch 'development' into refinery
bhavnicksm Dec 4, 2024
71a9d5d
Merge pull request #77 from bhavnicksm/refinery
bhavnicksm Dec 4, 2024
2652453
Replace TokenFactory with TokenProcessor
bhavnicksm Dec 5, 2024
7ebf586
Merge branch 'main' into development
bhavnicksm Dec 5, 2024
3601aed
Refactor TokenProcessor class for improved clarity and structure
bhavnicksm Dec 5, 2024
7152e99
Refactor OverlapRefinery and BaseRefinery for improved readability and structure
bhavnicksm Dec 5, 2024
3781536
Enhance OverlapRefinery to support prefix and suffix context modes
bhavnicksm Dec 5, 2024
621061f
Refactor code for improved readability and consistency
bhavnicksm Dec 5, 2024
e17e6d0
Refactor OverlapRefinery context handling and update tests
bhavnicksm Dec 5, 2024
ea04076
Update tests in OverlapRefinery to validate context handling
bhavnicksm Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor OverlapRefinery and BaseRefinery for improved readability and structure

- Cleaned up whitespace and formatting in the OverlapRefinery and BaseRefinery classes.
- Updated docstrings for clarity and consistency.
- Adjusted method signatures and internal logic for better readability.
- Ensured consistent use of commas and spacing in function definitions and calls.
  • Loading branch information
bhavnicksm committed Dec 5, 2024
commit 7152e99b6ed8b416b81c75ffd5efca160581b259
1 change: 0 additions & 1 deletion src/chonkie/refinery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@

# Include all the refinery classes in the __all__ list
__all__ = ["BaseRefinery", "OverlapRefinery"]

9 changes: 5 additions & 4 deletions src/chonkie/refinery/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, List
from typing import List

from chonkie.chunker import Chunk


class BaseRefinery(ABC):
"""Base class for all Refinery classes.

Refinery classes are used to refine the Chunks generated from the
Refinery classes are used to refine the Chunks generated from the
Chunkers. These classes take in chunks and return refined chunks.
Most refinery classes would be used to add additional context to the
chunks generated by the chunkers.
Expand All @@ -32,7 +33,7 @@ def is_available(cls) -> bool:
def __repr__(self) -> str:
"""Representation of the Refinery."""
return f"{self.__class__.__name__}(context_size={self.context_size})"

def __call__(self, chunks: List[Chunk]) -> List[Chunk]:
"""Call the Refinery."""
return self.refine(chunks)
return self.refine(chunks)
137 changes: 72 additions & 65 deletions src/chonkie/refinery/overlap.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Refinery class which adds overlap as context to chunks."""
from typing import Any, List, Optional
from dataclasses import dataclass

from chonkie.chunker import Chunk, SentenceChunk, SemanticChunk
from chonkie.chunker import Chunk, SemanticChunk, SentenceChunk
from chonkie.context import Context
from chonkie.refinery.base import BaseRefinery

from chonkie.context import Context

class OverlapRefinery(BaseRefinery):
"""Refinery class which adds overlap as context to chunks.

This refinery provides two methods for calculating overlap:
1. Exact: Uses a tokenizer to precisely determine token boundaries
2. Approximate: Estimates tokens based on text length ratios

It can handle different types of chunks (basic Chunks, SentenceChunks,
and SemanticChunks) and can optionally update the chunk text to include
the overlap content.
Expand All @@ -24,10 +24,10 @@ def __init__(
tokenizer: Any = None,
merge_context: bool = True,
inplace: bool = True,
approximate: bool = True
approximate: bool = True,
) -> None:
"""Initialize the OverlapRefinery class.

Args:
context_size: Number of tokens to include in context
tokenizer: Optional tokenizer for exact token counting
Expand All @@ -39,7 +39,7 @@ def __init__(
super().__init__(context_size)
self.merge_context = merge_context
self.inplace = inplace

# If tokenizer provided, we can do exact token counting
if tokenizer is not None:
self.tokenizer = tokenizer
Expand All @@ -48,24 +48,26 @@ def __init__(
# Without tokenizer, must use approximate method
self.approximate = True

def _get_refined_chunks(self, chunks: List[Chunk], inplace: bool = True) -> List[Chunk]:
def _get_refined_chunks(
self, chunks: List[Chunk], inplace: bool = True
) -> List[Chunk]:
"""Convert regular chunks to refined chunks with progressive memory cleanup.

This method takes regular chunks and converts them to RefinedChunks one at a
time. When inplace is True, it progressively removes chunks from the input
list to minimize memory usage.

The conversion preserves all relevant information from the original chunks,
including sentences and embeddings if they exist. This allows us to maintain
the full capabilities of semantic chunks while adding refinement features.

Args:
chunks: List of original chunks to convert
inplace: Whether to modify the input list during conversion

Returns:
List of RefinedChunks without any context (context is added later)

Example:
For memory efficiency with large datasets:
```
Expand All @@ -77,9 +79,9 @@ def _get_refined_chunks(self, chunks: List[Chunk], inplace: bool = True) -> List
"""
if not chunks:
return []

refined_chunks = []

# Use enumerate to track position without modifying list during iteration
for i in range(len(chunks)):
if inplace:
Expand All @@ -88,131 +90,132 @@ def _get_refined_chunks(self, chunks: List[Chunk], inplace: bool = True) -> List
else:
# Just get a reference if not modifying in place
chunk = chunks[i]

# Create refined version preserving appropriate attributes
refined_chunk = SemanticChunk(
text=chunk.text,
start_index=chunk.start_index,
end_index=chunk.end_index,
token_count=chunk.token_count,
# Preserve sentences and embeddings if they exist
sentences=chunk.sentences if isinstance(chunk, (SentenceChunk, SemanticChunk)) else None,
sentences=chunk.sentences
if isinstance(chunk, (SentenceChunk, SemanticChunk))
else None,
embedding=chunk.embedding if isinstance(chunk, SemanticChunk) else None,
context=None # Context is added later in the refinement process
context=None, # Context is added later in the refinement process
)

refined_chunks.append(refined_chunk)

if inplace:
# Clear the input list to free memory
chunks.clear()
chunks += refined_chunks

return refined_chunks

def _overlap_token_exact(self, chunk: Chunk) -> Optional[Context]:
"""Calculate precise token-based overlap context using tokenizer.

Takes a larger window of text from the chunk end, tokenizes it,
and selects exactly context_size tokens worth of text.

Args:
chunk: Chunk to extract context from

Returns:
Context object with precise token boundaries, or None if no tokenizer

"""
if not hasattr(self, 'tokenizer'):
if not hasattr(self, "tokenizer"):
return None

# Take 6x context_size characters to ensure enough tokens
char_window = min(len(chunk.text), self.context_size * 6)
text_portion = chunk.text[-char_window:]

# Get exact token boundaries
tokens = self.tokenizer.encode(text_portion)
context_tokens = min(self.context_size, len(tokens))
context_tokens_ids = tokens[-context_tokens:]
context_text = self.tokenizer.decode(context_tokens_ids)

# Find where context text starts in chunk
try:
context_start = chunk.text.rindex(context_text)
start_index = chunk.start_index + context_start

return Context(
text=context_text,
token_count=context_tokens,
start_index=start_index,
end_index=chunk.end_index
end_index=chunk.end_index,
)
except ValueError:
# If context text can't be found (e.g., due to special tokens), fall back to approximate
return self._overlap_token_approximate(chunk)

def _overlap_token_approximate(self, chunk: Chunk) -> Optional[Context]:
"""Calculate approximate token-based overlap context.

Estimates token positions based on character length ratios.

Args:
chunk: Chunk to extract context from

Returns:
Context object with estimated token boundaries

"""
# Calculate desired context size
context_tokens = min(self.context_size, chunk.token_count)

# Estimate text length based on token ratio
context_ratio = context_tokens / chunk.token_count
char_length = int(len(chunk.text) * context_ratio)

# Extract context text from end
context_text = chunk.text[-char_length:]

return Context(
text=context_text,
token_count=context_tokens,
start_index=chunk.end_index - char_length,
end_index=chunk.end_index
end_index=chunk.end_index,
)


def _overlap_token(self, chunk: Chunk) -> Optional[Context]:
"""Choose between exact or approximate token overlap calculation.

Args:
chunk: Chunk to process

Returns:
Context object from either exact or approximate calculation

"""
if self.approximate:
return self._overlap_token_approximate(chunk)
return self._overlap_token_exact(chunk)

def _overlap_sentence(self, chunk: SentenceChunk) -> Optional[Context]:
"""Calculate overlap context based on sentences.

Takes sentences from the end of the chunk up to context_size tokens.

Args:
chunk: SentenceChunk to process

Returns:
Context object containing complete sentences

"""
if not chunk.sentences:
return None

context_sentences = []
total_tokens = 0

# Add sentences from the end until we hit context_size
for sentence in reversed(chunk.sentences):
if total_tokens + sentence.token_count <= self.context_size:
Expand All @@ -224,14 +227,14 @@ def _overlap_sentence(self, chunk: SentenceChunk) -> Optional[Context]:
if not context_sentences:
context_sentences.append(chunk.sentences[-1])
total_tokens = chunk.sentences[-1].token_count

return Context(
text=" ".join(s.text for s in context_sentences),
token_count=total_tokens,
start_index=context_sentences[0].start_index,
end_index=context_sentences[-1].end_index
end_index=context_sentences[-1].end_index,
)

def _get_overlap_context(self, chunk: Chunk) -> Optional[Context]:
"""Get appropriate overlap context based on chunk type."""
if isinstance(chunk, SemanticChunk):
Expand All @@ -242,16 +245,16 @@ def _get_overlap_context(self, chunk: Chunk) -> Optional[Context]:
return self._overlap_token(chunk)
else:
raise ValueError(f"Unsupported chunk type: {type(chunk)}")

def refine(self, chunks: List[Chunk]) -> List[Chunk]:
"""Refine chunks by adding overlap context.

For each chunk after the first, adds context from the previous chunk.
Can optionally update the chunk text to include the context.

Args:
chunks: List of chunks to refine

Returns:
List of refined chunks with added context

Expand All @@ -267,32 +270,36 @@ def refine(self, chunks: List[Chunk]) -> List[Chunk]:
refined_chunks = [chunk.copy() for chunk in chunks]
else:
refined_chunks = chunks

# Process remaining chunks
for i in range(1, len(refined_chunks)):
# Get context from previous chunk
context = self._get_overlap_context(chunks[i-1])
setattr(refined_chunks[i], 'context', context)
context = self._get_overlap_context(chunks[i - 1])
setattr(refined_chunks[i], "context", context)

# Optionally update chunk text to include context
if self.merge_context and context:
refined_chunks[i].text = context.text + refined_chunks[i].text
refined_chunks[i].start_index = context.start_index
# Update token count to include context and space
# Calculate new token count
if hasattr(self, 'tokenizer') and not self.approximate:
if hasattr(self, "tokenizer") and not self.approximate:
# Use exact token count if we have a tokenizer
refined_chunks[i].token_count = len(self.tokenizer.encode(refined_chunks[i].text))
refined_chunks[i].token_count = len(
self.tokenizer.encode(refined_chunks[i].text)
)
else:
# Otherwise use approximate by adding context tokens plus one for space
refined_chunks[i].token_count = refined_chunks[i].token_count + context.token_count + 1
refined_chunks[i].token_count = (
refined_chunks[i].token_count + context.token_count + 1
)

return refined_chunks

@classmethod
def is_available(cls) -> bool:
"""Check if the OverlapRefinery is available.

Always returns True as this refinery has no external dependencies.
"""
return True
return True
Loading