from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, NamedTuple

import tree_sitter_languages
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter

from gpt_engineer.tools.experimental.supported_languages import SUPPORTED_LANGUAGES


class CodeSplitter(TextSplitter):
    """Split code using an AST parser."""

    def __init__(
        self,
        language: str,
        chunk_lines: int = 40,
        chunk_lines_overlap: int = 15,
        max_chars: int = 1500,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.language = language
        # chunk_lines and chunk_lines_overlap are stored but not read below;
        # split_text packs chunks purely by max_chars.
        self.chunk_lines = chunk_lines
        self.chunk_lines_overlap = chunk_lines_overlap
        self.max_chars = max_chars

    def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]:
        """Recursively pack a node's children into chunks of at most max_chars."""
        new_chunks = []
        current_chunk = ""
        for child in node.children:
            if child.end_byte - child.start_byte > self.max_chars:
                # Child is too big, recursively chunk the child
                if len(current_chunk) > 0:
                    new_chunks.append(current_chunk)
                current_chunk = ""
                new_chunks.extend(self._chunk_node(child, text, last_end))
            elif (
                len(current_chunk) + child.end_byte - child.start_byte > self.max_chars
            ):
                # Child would make the current chunk too big, so start a new chunk
                new_chunks.append(current_chunk)
                current_chunk = text[last_end : child.end_byte]
            else:
                current_chunk += text[last_end : child.end_byte]
            last_end = child.end_byte
        if len(current_chunk) > 0:
            new_chunks.append(current_chunk)
        return new_chunks

    def split_text(self, text: str) -> List[str]:
        """Split incoming code and return chunks using the AST."""
        try:
            parser = tree_sitter_languages.get_parser(self.language)
        except Exception as e:
            print(
                f"Could not get parser for language {self.language}. Check "
                "https://github.com/grantjenks/py-tree-sitter-languages#license "
                "for a list of valid languages."
            )
            raise e

        tree = parser.parse(bytes(text, "utf-8"))

        if not tree.root_node.children or tree.root_node.children[0].type != "ERROR":
            chunks = [
                chunk.strip() for chunk in self._chunk_node(tree.root_node, text)
            ]
            return chunks
        else:
            raise ValueError(f"Could not parse code with language {self.language}.")


class SortedDocuments(NamedTuple):
    by_language: Dict[str, List[Document]]
    other: List[Document]


class DocumentChunker:
    @staticmethod
    def chunk_documents(documents: List[Document]) -> List[Document]:
        chunked_documents = []

        sorted_documents = (
            DocumentChunker._sort_documents_by_programming_language_or_other(documents)
        )

        for language, language_documents in sorted_documents.by_language.items():
            code_splitter = CodeSplitter(
                language=language.lower(),
                chunk_lines=40,
                chunk_lines_overlap=15,
                max_chars=1500,
            )
            chunked_documents.extend(code_splitter.split_documents(language_documents))

        # for now only include code files!
        # chunked_documents.extend(sorted_documents.other)

        return chunked_documents

    @staticmethod
    def _sort_documents_by_programming_language_or_other(
        documents: List[Document],
    ) -> SortedDocuments:
        docs_to_split = defaultdict(list)
        other_docs = []

        for doc in documents:
            filename = str(doc.metadata.get("filename"))
            extension = Path(filename).suffix

            language_found = False
            for lang in SUPPORTED_LANGUAGES:
                if extension in lang["extensions"]:
                    doc.metadata["is_code"] = True
                    doc.metadata["code_language"] = lang["name"]
                    doc.metadata["code_language_tree_sitter_name"] = lang[
                        "tree_sitter_name"
                    ]
                    docs_to_split[lang["tree_sitter_name"]].append(doc)
                    language_found = True
                    break

            if not language_found:
                doc.metadata["is_code"] = False
                other_docs.append(doc)

        return SortedDocuments(by_language=dict(docs_to_split), other=other_docs)
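

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # "python" tree-sitter grammar is available via py-tree-sitter-languages
    # and that SUPPORTED_LANGUAGES maps the ".py" extension to it.
    doc = Document(
        page_content="def add(a, b):\n    return a + b\n",
        metadata={"filename": "example.py"},
    )
    for chunk in DocumentChunker.chunk_documents([doc]):
        print(chunk.metadata["code_language"], repr(chunk.page_content))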