import os
import re
import zipfile

import javalang
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Check for the dataset and unzip it if not already extracted
dataset_folder = "Subject_CloneTypes_Directories"
if not os.path.exists(dataset_folder):
    with zipfile.ZipFile("Subject_CloneTypes_Directories.zip", 'r') as zip_ref:
        zip_ref.extractall(dataset_folder)
    print("✅ Dataset extracted!")
else:
    print("✅ Dataset already extracted!")

# Configuration
MAX_FILE_SIZE = 5000
MAX_AST_DEPTH = 50
EMBEDDING_DIM = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Code Normalization
def normalize_code(code):
    """Strip comments and string literals so superficial differences don't affect matching."""
    code = re.sub(r'//.*', '', code)                         # line comments
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)   # block comments
    code = re.sub(r'"[^"]*"', '"STRING"', code)              # string literals
    code = re.sub(r'\s+', ' ', code).strip()                 # collapse whitespace
    return code
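
# Quick illustrative check (the inputs are assumed examples, not from the original
# app): two snippets that differ only in comments and whitespace normalize to the
# same canonical string, which is exactly what Type-1 clone detection relies on.
assert normalize_code('int x = 1; // counter') == normalize_code('int  x =\n1;')
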
# AST Extraction
def parse_java(code):
    """Parse Java source into a javalang AST; return None if parsing fails."""
    try:
        tokens = javalang.tokenizer.tokenize(code)
        parser = javalang.parser.Parser(tokens)
        return parser.parse()
    except Exception:  # javalang raises several error types (JavaSyntaxError, LexerError, ...)
        return None

# AST Processor
class ASTProcessor:
    def __init__(self):
        self.node_types = set()

    def extract_paths(self, node, max_depth=MAX_AST_DEPTH):
        """Return all root-to-leaf node-type paths, truncated at max_depth."""
        paths = []
        self._dfs(node, [], paths, 0, max_depth)
        return paths

    def _dfs(self, node, current_path, paths, depth, max_depth):
        if depth > max_depth:
            return
        current_path.append(type(node).__name__)
        # Collect the actual AST-node children (javalang mixes nodes, lists, and literals)
        child_nodes = []
        if hasattr(node, 'children'):
            for child in node.children:
                if isinstance(child, javalang.ast.Node):
                    child_nodes.append(child)
                elif isinstance(child, (list, tuple)):
                    child_nodes.extend(c for c in child if isinstance(c, javalang.ast.Node))
        # Record a complete path at leaves or when the depth cap is reached
        if not child_nodes or depth == max_depth:
            paths.append(current_path.copy())
        else:
            for child in child_nodes:
                self._dfs(child, current_path, paths, depth + 1, max_depth)
        current_path.pop()
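
# Note: the original file never shows how these string paths become the integer
# tensors that ASTEncoder.forward expects. The helper below is a minimal sketch
# of one plausible encoding; the name, the padding scheme, and the on-the-fly
# vocabulary are assumptions, not part of the original pipeline.
def encode_paths(paths, vocab, max_len=MAX_AST_DEPTH):
    """Map node-type paths to a padded (num_paths, max_len) LongTensor of indices.

    `vocab` is assumed to be a dict seeded with {'<pad>': 0}; unseen node types
    are assigned the next free index as they are encountered.
    """
    encoded = []
    for path in paths:
        ids = [vocab.setdefault(tok, len(vocab)) for tok in path][:max_len]
        ids += [0] * (max_len - len(ids))  # right-pad with the <pad> index
        encoded.append(ids)
    return torch.tensor(encoded, dtype=torch.long)
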
# Model
class ASTEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, embedding_dim, batch_first=True)

    def forward(self, paths):
        # paths: (batch, seq_len) LongTensor of node-type indices
        embedded = self.embedding(paths)
        _, (hidden, _) = self.lstm(embedded)
        return hidden[-1]  # final hidden state: (batch, embedding_dim)

class CodeBERTEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained('microsoft/codebert-base')
        self.tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
        # CodeBERT's hidden size is 768; project it down to EMBEDDING_DIM so the
        # classifier below can concatenate it with the 128-dim AST embedding
        self.proj = nn.Linear(self.bert.config.hidden_size, EMBEDDING_DIM)

    def forward(self, code):
        inputs = self.tokenizer(code, return_tensors='pt', truncation=True, padding=True)
        inputs = {k: v.to(self.bert.device) for k, v in inputs.items()}
        outputs = self.bert(**inputs)
        # Mean-pool the token embeddings into a single vector per snippet
        return self.proj(outputs.last_hidden_state.mean(dim=1))

class HybridCloneDetector(nn.Module):
    def __init__(self, ast_vocab_size):
        super().__init__()
        self.ast_encoder = ASTEncoder(ast_vocab_size, EMBEDDING_DIM)
        self.code_encoder = CodeBERTEncoder()
        self.classifier = nn.Sequential(
            nn.Linear(EMBEDDING_DIM * 2, EMBEDDING_DIM),
            nn.ReLU(),
            nn.Linear(EMBEDDING_DIM, 2)  # two classes: clone / not a clone
        )

    def forward(self, ast1, code1, ast2, code2):
        ast_emb1 = self.ast_encoder(ast1)
        ast_emb2 = self.ast_encoder(ast2)
        code_emb1 = self.code_encoder(code1)
        code_emb2 = self.code_encoder(code2)
        # Element-wise absolute differences act as similarity features
        diff_ast = torch.abs(ast_emb1 - ast_emb2)
        diff_code = torch.abs(code_emb1 - code_emb2)
        combined = torch.cat([diff_ast, diff_code], dim=1)
        return self.classifier(combined)
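
# The UI below only demonstrates parsing and normalization; no trained weights
# ship with this demo. The function below is a minimal sketch of how inference
# could be wired up once a trained checkpoint exists. Crudely flattening each
# snippet's AST paths into one index sequence, and every name here, are
# illustrative assumptions, not the original training setup.
def predict_clone(model, vocab, norm_code1, norm_code2, ast_root1, ast_root2):
    """Return the softmax probability that the two snippets are clones."""
    processor = ASTProcessor()
    # Flatten each snippet's paths into a single (1, L) index sequence
    ids1 = encode_paths(processor.extract_paths(ast_root1), vocab).flatten()[:512].unsqueeze(0)
    ids2 = encode_paths(processor.extract_paths(ast_root2), vocab).flatten()[:512].unsqueeze(0)
    model.eval()
    with torch.no_grad():
        logits = model(ids1, [norm_code1], ids2, [norm_code2])
        return torch.softmax(logits, dim=1)[0, 1].item()
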
# Streamlit UI
st.title("Java Code Clone Detector")

uploaded_file1 = st.file_uploader("Upload Java File 1", type=["java"])
uploaded_file2 = st.file_uploader("Upload Java File 2", type=["java"])

if uploaded_file1 and uploaded_file2:
    code1 = uploaded_file1.read().decode('utf-8')
    code2 = uploaded_file2.read().decode('utf-8')

    # Normalize code
    norm_code1 = normalize_code(code1)
    norm_code2 = normalize_code(code2)

    # Parse AST
    ast1 = parse_java(norm_code1)
    ast2 = parse_java(norm_code2)

    if ast1 is None or ast2 is None:
        st.error("Failed to parse one of the files. Please upload valid Java code.")
    else:
        st.success("Files parsed successfully.")

        # Inference (placeholder)
        st.write("🔧 **Model loading...** (currently using placeholder)")
        # In a real app you would load your trained model here
        st.warning("Model inference is not available yet in this simple demo.")

        st.write("✅ Code normalization done.")
        st.code(norm_code1[:500], language='java')
        st.code(norm_code2[:500], language='java')

        st.info("Clone detection: [Placeholder] Results will appear here after training integration.")
else:
    st.info("Upload two Java files to start clone detection.")