File size: 4,485 Bytes
d26280a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, NamedTuple

import tree_sitter_languages

from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter

from gpt_engineer.tools.experimental.supported_languages import SUPPORTED_LANGUAGES


class CodeSplitter(TextSplitter):
    """Split source code into chunks along syntax-tree node boundaries.

    The code is parsed with the tree-sitter parser for ``language``; sibling
    nodes are then packed together into chunks of roughly ``max_chars`` each.
    """

    def __init__(
        self,
        language: str,
        chunk_lines: int = 40,
        chunk_lines_overlap: int = 15,
        max_chars: int = 1500,
        **kwargs,
    ):
        """Create a splitter for ``language``.

        Args:
            language: Name passed to ``tree_sitter_languages.get_parser``.
            chunk_lines: Stored on the instance; not read by the methods
                defined in this class.
            chunk_lines_overlap: Stored on the instance; not read by the
                methods defined in this class.
            max_chars: Size threshold used when packing nodes into a chunk.
            **kwargs: Forwarded unchanged to the ``TextSplitter`` initializer.
        """
        super().__init__(**kwargs)

        self.language = language
        self.chunk_lines = chunk_lines
        self.chunk_lines_overlap = chunk_lines_overlap
        self.max_chars = max_chars

    def _chunk_node(self, node: Any, text: str, last_end: int = 0) -> List[str]:
        """Chunk the children of ``node`` and return the chunk strings.

        Args:
            node: Syntax-tree node whose children are packed into chunks.
            text: The full source text that was parsed.
            last_end: Byte offset (into the UTF-8 encoding of ``text``) where
                the next chunk starts.

        Returns:
            The chunks in source order.
        """
        # Node offsets are byte offsets into the UTF-8 buffer handed to the
        # parser, so slicing has to happen on bytes, not on the str; slicing
        # the str directly misplaces every boundary after a non-ASCII char.
        return self._chunk_node_bytes(node, text.encode("utf-8"), last_end)

    def _chunk_node_bytes(
        self, node: Any, source: bytes, last_end: int = 0
    ) -> List[str]:
        """Do the work of ``_chunk_node`` on the already-encoded source."""
        new_chunks: List[str] = []
        current_chunk = ""
        for child in node.children:
            child_size = child.end_byte - child.start_byte
            if child_size > self.max_chars:
                # Child is too big for one chunk: flush what we have first.
                if current_chunk:
                    new_chunks.append(current_chunk)
                current_chunk = ""
                if child.children:
                    # Recursively chunk the child.
                    new_chunks.extend(
                        self._chunk_node_bytes(child, source, last_end)
                    )
                else:
                    # A childless node cannot be subdivided; emit it as one
                    # oversized chunk rather than silently dropping its text.
                    new_chunks.append(
                        source[last_end : child.end_byte].decode(
                            "utf-8", errors="replace"
                        )
                    )
            else:
                # errors="replace" is defensive: the slice stays decodable
                # even if a boundary were to fall inside a multi-byte char.
                piece = source[last_end : child.end_byte].decode(
                    "utf-8", errors="replace"
                )
                if len(current_chunk) + child_size > self.max_chars:
                    # Child would make the current chunk too big, so start a
                    # new chunk.
                    new_chunks.append(current_chunk)
                    current_chunk = piece
                else:
                    current_chunk += piece
            last_end = child.end_byte
        if current_chunk:
            new_chunks.append(current_chunk)
        return new_chunks

    def split_text(self, text: str) -> List[str]:
        """Split incoming code and return chunks using the AST.

        Args:
            text: Source code written in ``self.language``.

        Returns:
            The chunks, each stripped of surrounding whitespace.

        Raises:
            Exception: Whatever ``get_parser`` raised for ``self.language``
                (re-raised after printing a hint).
            ValueError: If the first top-level node of the parse is an
                ``ERROR`` node.
        """
        try:
            parser = tree_sitter_languages.get_parser(self.language)
        except Exception:
            print(
                f"Could not get parser for language {self.language}. Check "
                "https://github.com/grantjenks/py-tree-sitter-languages#license "
                "for a list of valid languages."
            )
            raise

        tree = parser.parse(bytes(text, "utf-8"))
        root = tree.root_node

        if root.children and root.children[0].type == "ERROR":
            raise ValueError(f"Could not parse code with language {self.language}.")

        return [chunk.strip() for chunk in self._chunk_node(root, text)]


class SortedDocuments(NamedTuple):
    """Documents partitioned into per-language groups plus a remainder."""

    # Maps a language identifier to the documents grouped under it.
    by_language: Dict[str, List[Document]]
    # Documents that were not assigned to any language group.
    other: List[Document]


class DocumentChunker:
    """Chunk a collection of documents with a per-language ``CodeSplitter``."""

    @staticmethod
    def chunk_documents(documents: List[Document]) -> List[Document]:
        """Split every code document into chunks and return the chunks.

        The documents are first grouped by programming language; each group
        is then split by a ``CodeSplitter`` configured for that language.

        Args:
            documents: Documents to chunk.

        Returns:
            The chunked documents of all language groups. Documents that
            were not assigned to a language are currently left out.
        """
        # Declared static: the method takes no ``self``, so without the
        # decorator it could only be called on the class, not on an instance.
        chunked_documents = []

        sorted_documents = _sort_documents_by_programming_language_or_other(documents)

        for language, language_documents in sorted_documents.by_language.items():
            code_splitter = CodeSplitter(
                language=language.lower(),
                chunk_lines=40,
                chunk_lines_overlap=15,
                max_chars=1500,
            )

            chunked_documents.extend(code_splitter.split_documents(language_documents))

        # for now only include code files!
        # chunked_documents.extend(sorted_documents.other)

        return chunked_documents


def _sort_documents_by_programming_language_or_other(
    documents: List[Document],
) -> SortedDocuments:
    """Group documents by the programming language of their file extension.

    The extension of each document's ``metadata["filename"]`` is looked up in
    ``SUPPORTED_LANGUAGES``. Matching documents are tagged in place with
    ``is_code``, ``code_language`` and ``code_language_tree_sitter_name`` and
    grouped under the language's tree-sitter name. All other documents are
    tagged ``is_code = False`` and collected separately.

    Args:
        documents: Documents to sort; their ``metadata`` dicts are mutated.

    Returns:
        A ``SortedDocuments`` holding the per-language groups and the rest.
    """
    # NOTE: this is a plain module-level function. It must not be wrapped in
    # ``@staticmethod``: such an object is not callable before Python 3.10.
    docs_to_split = defaultdict(list)
    other_docs = []

    for doc in documents:
        # A missing filename becomes the string "None", which has no suffix.
        extension = Path(str(doc.metadata.get("filename"))).suffix

        # First language entry that lists this extension, if any.
        lang = next(
            (
                candidate
                for candidate in SUPPORTED_LANGUAGES
                if extension in candidate["extensions"]
            ),
            None,
        )

        if lang is None:
            # Same key as the code branch below (was the inconsistent "isCode").
            doc.metadata["is_code"] = False
            other_docs.append(doc)
            continue

        doc.metadata["is_code"] = True
        doc.metadata["code_language"] = lang["name"]
        doc.metadata["code_language_tree_sitter_name"] = lang["tree_sitter_name"]
        docs_to_split[lang["tree_sitter_name"]].append(doc)

    return SortedDocuments(by_language=dict(docs_to_split), other=other_docs)