File size: 15,298 Bytes
2a1cdbf
bcb80f2
 
 
 
2a1cdbf
bcb80f2
 
2a1cdbf
bcb80f2
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
be94910
bcb80f2
 
 
be94910
bcb80f2
 
 
be94910
bcb80f2
 
 
 
 
be94910
bcb80f2
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be94910
bcb80f2
 
 
 
be94910
bcb80f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import time

# Model constants
# Hugging Face model ID passed to ``from_pretrained`` by CodeT5Summarizer for
# both its tokenizer and its seq2seq model.
CODET5_MODEL = "Salesforce/codet5-base-multi-sum"

class CodeT5Summarizer:
    """Summarize Python source code with the CodeT5 model named by CODET5_MODEL.

    Produces a file-level summary plus, optionally, one summary per function
    and per class. Function and class boundaries are located with a
    lightweight line scanner that follows indentation and triple-quoted
    strings. It is a heuristic, not a Python parser, so code that does not
    parse is still handled on a best-effort basis.
    """

    # Definition headers, matched against a line with its indentation removed.
    _DEF_RE = re.compile(r'(?:async\s+)?def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(')
    _CLASS_RE = re.compile(r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)')

    def __init__(self, device=None):
        """Load the CodeT5 tokenizer and model.

        Args:
            device: Torch device to run on. When falsy, ``'cuda'`` is used
                if available, otherwise ``'cpu'``.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Loading the checkpoint can be slow, so show a spinner in the UI.
        with st.spinner("Loading CodeT5 model... this may take a minute..."):
            self.tokenizer = AutoTokenizer.from_pretrained(CODET5_MODEL)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(CODET5_MODEL).to(self.device)

    @staticmethod
    def _indent_of(line):
        """Return the number of leading whitespace characters in *line*."""
        return len(line) - len(line.lstrip())

    @staticmethod
    def _triple_quote_state(line, open_delim):
        """Return the triple-quote delimiter still open after scanning *line*.

        *open_delim* is the delimiter open at the start of the line (``None``
        when outside any triple-quoted string). Quotes are scanned left to
        right, so ``\"\"\"one-liner\"\"\"`` opens and closes on the same line.
        Escaped quotes and quotes inside ordinary strings or comments are not
        recognized; this is a heuristic, not a tokenizer.
        """
        position = 0
        while True:
            if open_delim is None:
                hits = [(line.find(quote, position), quote) for quote in ('"""', "'''")]
                hits = [hit for hit in hits if hit[0] != -1]
                if not hits:
                    return None
                found, open_delim = min(hits)  # earliest opening delimiter
            else:
                found = line.find(open_delim, position)
                if found == -1:
                    return open_delim
                open_delim = None
            position = found + 3

    def preprocess_code(self, code):
        """Clean Python source before it is summarized.

        * Removes blank lines.
        * Removes full-line ``#`` comments, but keeps every line that lies
          inside a triple-quoted string (docstrings included).
        * Collapses runs of spaces *within* a line to one space while leaving
          leading indentation untouched. The function/class extractors rely
          on indentation to find block boundaries.

        Args:
            code: Python source text.

        Returns:
            The cleaned source text.
        """
        # Remove empty lines
        code = re.sub(r'\n\s*\n', '\n', code)

        code_lines = []
        open_delim = None  # delimiter of the triple-quoted string we are inside
        for line in code.split('\n'):
            # A '#' line is a comment only when it is not part of a string.
            if open_delim is None and line.strip().startswith('#'):
                continue
            code_lines.append(line)
            open_delim = self._triple_quote_state(line, open_delim)

        processed_code = '\n'.join(code_lines)

        # Normalize interior whitespace; the lookbehind keeps indentation intact.
        return re.sub(r'(?<=\S) {2,}', ' ', processed_code)

    def _header_end(self, lines, start):
        """Return the index just past the (possibly wrapped) header at ``lines[start]``.

        The header ends on the first line where brackets opened since *start*
        are balanced, or on a line ending with ``:``. The ``:`` check guards
        against brackets inside string defaults such as ``x="("``.
        """
        depth = 0
        index = start
        while index < len(lines):
            line = lines[index]
            depth += sum(map(line.count, '([{')) - sum(map(line.count, ')]}'))
            index += 1
            if depth <= 0 or line.rstrip().endswith(':'):
                break
        return index

    def _block_end(self, lines, start):
        """Return the index just past the def/class block whose header is ``lines[start]``.

        The body is every line after the header that is indented deeper than
        the header. Blank lines, comment lines and lines inside triple-quoted
        strings never end the block, and trailing blank/comment lines are not
        included.
        """
        header_indent = self._indent_of(lines[start])
        index = self._header_end(lines, start)
        end = index
        open_delim = None
        while index < len(lines):
            line = lines[index]
            stripped = line.strip()
            if open_delim is None:
                if not stripped or stripped.startswith('#'):
                    index += 1
                    continue
                if self._indent_of(line) <= header_indent:
                    break
            open_delim = self._triple_quote_state(line, open_delim)
            index += 1
            end = index
        return end

    def _iter_definitions(self, code):
        """Yield ``(kind, name, source)`` for every def/class in *code*, in source order.

        *kind* is ``'function'`` or ``'class'``. *name* is dotted with the
        names of all enclosing defs/classes (``Outer.method``,
        ``func.helper``), so each definition is reported exactly once and
        methods of different classes never collide. The header line of
        *source* has its leading indentation removed; the remaining lines
        keep their original indentation.
        """
        lines = code.split('\n')
        scopes = []  # (indent, name) of enclosing definitions, innermost last
        open_delim = None
        skip_until = 0  # lines before this index continue a wrapped header
        for index, line in enumerate(lines):
            stripped = line.strip()
            in_string = open_delim is not None
            if not in_string and stripped.startswith('#'):
                continue
            open_delim = self._triple_quote_state(line, open_delim)
            if in_string or not stripped or index < skip_until:
                continue

            indent = self._indent_of(line)
            # Leave every definition this line is not nested inside.
            while scopes and scopes[-1][0] >= indent:
                scopes.pop()

            kind = 'function'
            match = self._DEF_RE.match(stripped)
            if match is None:
                kind = 'class'
                match = self._CLASS_RE.match(stripped)
                if match is None:
                    continue

            name = match.group(1)
            qualified_name = '.'.join([scope_name for _, scope_name in scopes] + [name])
            end = self._block_end(lines, index)
            source = '\n'.join([line.lstrip()] + lines[index + 1:end])
            skip_until = self._header_end(lines, index)
            scopes.append((indent, name))
            yield kind, qualified_name, source

    def extract_functions(self, code):
        """Extract every function and method in *code* for summarization.

        Returns:
            A list of ``(name, source)`` tuples in source order. Methods and
            nested functions are named with their enclosing definitions
            (``ClassName.method``), and each appears exactly once.
        """
        return [
            (name, source)
            for kind, name, source in self._iter_definitions(code)
            if kind == 'function'
        ]

    def extract_classes(self, code):
        """Extract every class definition in *code* for summarization.

        Returns:
            A list of ``(name, source)`` tuples in source order. Nested
            classes are named with their enclosing definitions
            (``Outer.Inner``).
        """
        return [
            (name, source)
            for kind, name, source in self._iter_definitions(code)
            if kind == 'class'
        ]

    def summarize(self, code, max_length=50, max_input_length=512):
        """Generate a natural-language summary of *code* with CodeT5.

        Args:
            code: Source text to summarize.
            max_length: ``max_length`` passed to ``generate`` for the summary.
            max_input_length: Inputs longer than this many tokens are
                truncated (CodeT5 typically accepts up to 512 tokens).

        Returns:
            The decoded summary string.
        """
        inputs = self.tokenizer(
            code, truncation=True, max_length=max_input_length, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly instead of letting generate infer it.
                attention_mask=inputs.get("attention_mask"),
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
        """Summarize *code* at file level and, optionally, per function/class.

        Args:
            code: Raw Python source; it is cleaned with ``preprocess_code``.
            summarize_functions: Also summarize each function/method.
            summarize_classes: Also summarize each class.

        Returns:
            ``{"file_summary": str, "function_summaries": {name: str},
            "class_summaries": {name: str}}``. A failure while summarizing
            any single item is stored as an error string for that item rather
            than raised, so one bad block does not abort the rest.
        """
        preprocessed_code = self.preprocess_code(code)

        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {}
        }

        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"

        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    results["function_summaries"][function_name] = self.summarize(function_code)
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"

        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    results["class_summaries"][class_name] = self.summarize(class_code)
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"

        return results

def _summarization_options(func_key, class_key):
    """Render the "Summarization Options" checkboxes.

    Args:
        func_key: Streamlit widget key for the function-summary checkbox.
        class_key: Streamlit widget key for the class-summary checkbox.

    Returns:
        ``(summarize_functions, summarize_classes)`` as booleans.
    """
    st.subheader("Summarization Options")
    col1, col2 = st.columns(2)
    with col1:
        summarize_functions = st.checkbox("Generate function summaries", value=True, key=func_key)
    with col2:
        summarize_classes = st.checkbox("Generate class summaries", value=True, key=class_key)
    return summarize_functions, summarize_classes


def _summarize_and_render(code, summarize_functions, summarize_classes):
    """Summarize *code* with the session's summarizer and render the results."""
    with st.spinner("Generating summaries..."):
        # perf_counter is monotonic, so the measured duration cannot go negative.
        start_time = time.perf_counter()
        summaries = st.session_state.summarizer.summarize_code(
            code,
            summarize_functions=summarize_functions,
            summarize_classes=summarize_classes
        )
        elapsed = time.perf_counter() - start_time

    st.success(f"Summarization completed in {elapsed:.2f} seconds!")

    st.subheader("File Summary")
    st.write(summaries["file_summary"])

    if summarize_functions and summaries["function_summaries"]:
        st.subheader("Function Summaries")
        for func_name, summary in summaries["function_summaries"].items():
            with st.expander(f"Function: {func_name}"):
                st.write(summary)

    if summarize_classes and summaries["class_summaries"]:
        st.subheader("Class Summaries")
        for class_name, summary in summaries["class_summaries"].items():
            with st.expander(f"Class: {class_name}"):
                st.write(summary)


def _decode_upload(uploaded_file):
    """Return the uploaded file's text, or ``None`` after showing an error if it is not UTF-8."""
    try:
        # 'utf-8-sig' decodes plain UTF-8 and also drops a leading byte-order mark.
        return uploaded_file.getvalue().decode('utf-8-sig')
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text and cannot be summarized.")
        return None


def main():
    """Render the Streamlit UI: upload and paste tabs, each with summarization, plus an About section."""
    st.set_page_config(
        page_title="Python Code Summarizer",
        page_icon="📝",
        layout="wide"
    )

    st.title("📝 Python Code Summarizer using CodeT5")
    st.markdown("""
    Upload a Python file or paste code directly to generate summaries.
    This app uses CodeT5, a pretrained model for code understanding and generation.
    """)

    # Load the model once and keep it in session state across reruns.
    # NOTE(review): session state is per browser session, so each session loads
    # its own model copy; consider st.cache_resource if the Streamlit version allows.
    if st.session_state.get('summarizer') is None:
        st.session_state.summarizer = CodeT5Summarizer()

    tab1, tab2 = st.tabs(["Upload Python File", "Paste Code"])

    with tab1:
        uploaded_file = st.file_uploader("Choose a Python file", type=['py'])
        code = _decode_upload(uploaded_file) if uploaded_file is not None else None
        if code is not None:
            with st.expander("View Uploaded Code", expanded=False):
                st.code(code, language='python')

            summarize_functions, summarize_classes = _summarization_options("func_file", "class_file")
            if st.button("Summarize Code", key="summarize_file"):
                _summarize_and_render(code, summarize_functions, summarize_classes)

    with tab2:
        code = st.text_area("Paste Python code here", height=300)
        if code:
            summarize_functions, summarize_classes = _summarization_options("func_paste", "class_paste")
            if st.button("Summarize Code", key="summarize_paste"):
                _summarize_and_render(code, summarize_functions, summarize_classes)

    st.markdown("---")
    st.markdown("### About")
    st.markdown("""
    This app uses the CodeT5 model to generate summaries of Python code. The model is trained on a large corpus of code and documentation.
    
    **Features:**
    - File-level summaries
    - Function-level summaries
    - Class-level summaries
    
    **Limitations:**
    - Summaries may not always be accurate
    - Long files may be truncated
    - Complex code structures might not be properly understood
    """)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()