|
|
|
function initializeDragAndDrop() { |
|
// Palette entries that can be dragged onto the canvas, and the canvas itself.
const nodeItems = document.querySelectorAll('.node-item');
const canvas = document.getElementById('network-canvas');

// Element currently being dragged: a palette item during an HTML5 drag,
// or a canvas node while it is being repositioned with the mouse.
let draggedNode = null;
// Pointer offset inside the node being repositioned (set by startDrag).
let offsetX;
let offsetY;
let isDragging = false;

// State for drawing a connection from an output port.
let isConnecting = false;
let startNode = null;
let connectionLine = null;

// Per-type count of nodes created by drops (incremented by the drop handler).
const nodeCounter = {};

// Bookkeeping that lets the drop handler process each palette drag only once.
const recentlyCreated = {
    nodeIds: new Set(),     // ids of recently created layers (the drop handler keeps at most 10)
    dragStartTime: 0,       // Date.now() at the most recent palette dragstart
    isDropHandled: false,   // set once a drop has been processed for the current gesture
    inProgress: false,      // true while a palette drag is active
    timestamp: 0            // Date.now() of the last handled drop (500 ms debounce window)
};

// In-memory model of the network: layer records plus directed source → target edges.
let networkLayers = {
    layers: [],
    connections: []
};
|
|
|
|
|
/**
 * Formats a number compactly with a K / M / B suffix and two decimals.
 *
 * - `0` (and `-0`) is rendered as `'0'`.
 * - Any other falsy value (`null`, `undefined`, `NaN`, `''`) is rendered as `'N/A'`.
 * - Magnitudes below 1000 are returned unchanged via `toString()`.
 * - Negative values are abbreviated like positive ones, with a leading `-`.
 * - When two-decimal rounding would print `1000.00` of a unit (e.g. 999,999 →
 *   `1000.00K`), the next larger unit is used instead (`1.00M`).
 *
 * @param {number} num - Value to format.
 * @returns {string} Human-readable representation.
 */
function formatNumber(num) {
    if (num === 0) return '0';
    if (!num) return 'N/A';

    // Largest unit first, so the first scale that fits is the one used.
    const units = [
        [1e9, 'B'],
        [1e6, 'M'],
        [1e3, 'K']
    ];
    const magnitude = Math.abs(num);
    const sign = num < 0 ? '-' : '';

    for (let i = 0; i < units.length; i++) {
        const [scale, suffix] = units[i];
        if (magnitude < scale) continue;

        const formatted = (magnitude / scale).toFixed(2);
        // Rounding can carry into the next unit (999.999K -> "1000.00K"); promote it.
        if (formatted === '1000.00' && i > 0) {
            const [largerScale, largerSuffix] = units[i - 1];
            return `${sign}${(magnitude / largerScale).toFixed(2)}${largerSuffix}`;
        }
        return `${sign}${formatted}${suffix}`;
    }

    return num.toString();
}
|
|
|
|
|
// Make every palette item draggable onto the canvas. The drag payload carries the
// item's data-type; `recentlyCreated` and `draggedNode` let the canvas drop handler
// recognise the gesture.
nodeItems.forEach(paletteItem => {
    paletteItem.addEventListener('dragstart', function (e) {
        // New gesture: reset the per-drag bookkeeping.
        Object.assign(recentlyCreated, {
            isDropHandled: false,
            inProgress: true,
            dragStartTime: Date.now()
        });

        const type = this.getAttribute('data-type');
        const transfer = e.dataTransfer;

        // Publish the layer type under a generic and an app-specific MIME type.
        transfer.setData('text/plain', type);
        transfer.setData('application/x-neural-node-type', type);

        // Extra best-effort copies as expando properties; the drop handler checks
        // these if getData() comes back empty. Failures are deliberately ignored.
        try {
            transfer.nodeType = type;
            transfer._neural_type = type;
        } catch (err) {
            // ignored: setData above already carries the type
        }

        draggedNode = this;

        // Use a half-transparent clone as the drag image; it is appended so the
        // browser can render it and detached again on the next tick.
        const preview = this.cloneNode(true);
        preview.style.opacity = '0.5';
        document.body.appendChild(preview);
        transfer.setDragImage(preview, 0, 0);
        setTimeout(() => {
            document.body.removeChild(preview);
        }, 0);

        // 100 ms after the gesture ends, clear the drag state. `once` detaches the
        // listener after it fires, so repeated drags don't stack handlers.
        paletteItem.addEventListener('dragend', () => {
            setTimeout(() => {
                recentlyCreated.inProgress = false;
                draggedNode = null;
            }, 100);
        }, { once: true });
    });
});
|
|
|
|
|
/**
 * `dragover` handler for the canvas. Cancelling the event marks the canvas as a
 * valid drop target, and the cursor feedback is set to "copy".
 *
 * @param {DragEvent} e - The dragover event.
 */
function handleDragOver(e) {
    const { dataTransfer } = e;
    e.preventDefault();
    dataTransfer.dropEffect = 'copy';
}
|
|
|
|
|
canvas.addEventListener('dragover', handleDragOver);

/*
 * Drop handler: turns a palette item dropped on the canvas into a new layer node.
 * It reads the layer type from the drag payload and places the node under the
 * pointer (clamped to the canvas). It then builds the node's DOM, wires up its
 * interactions, records it in `networkLayers`, and dispatches `networkUpdated`.
 */
canvas.addEventListener('drop', function dropHandler(e) {
    e.preventDefault();

    // A drop for this drag gesture has already been processed.
    if (recentlyCreated.isDropHandled) {
        return;
    }

    const now = Date.now();

    // Debounce: ignore a drop that arrives within 500 ms of the last handled one.
    if (now - recentlyCreated.timestamp < 500) {
        return;
    }

    recentlyCreated.isDropHandled = true;
    recentlyCreated.timestamp = now;

    // Only palette items whose dragstart was seen by this module can create layers.
    if (!recentlyCreated.inProgress || !draggedNode || !draggedNode.classList.contains('node-item')) {
        return;
    }

    // Resolve the layer type, trying each place the dragstart handler stored it.
    let nodeType = null;
    try {
        nodeType = e.dataTransfer.getData('text/plain');

        if (!nodeType) {
            nodeType = e.dataTransfer.getData('application/x-neural-node-type');
        }
        if (!nodeType && e.dataTransfer.nodeType) {
            nodeType = e.dataTransfer.nodeType;
        }
        if (!nodeType && e.dataTransfer._neural_type) {
            nodeType = e.dataTransfer._neural_type;
        }
    } catch (err) {
        console.warn('[WARN] Could not read node type from drop payload', err);
    }

    // Last resort: the palette element itself (draggedNode is non-null here).
    if (!nodeType) {
        nodeType = draggedNode.getAttribute('data-type');
    }

    if (!nodeType) {
        return;
    }

    // Offset the pointer by (75, 30) px, then clamp so a ~150×100 px node stays
    // inside the canvas.
    const canvasRect = canvas.getBoundingClientRect();
    const x = e.clientX - canvasRect.left - 75;
    const y = e.clientY - canvasRect.top - 30;

    const posX = Math.max(0, Math.min(canvasRect.width - 150, x));
    const posY = Math.max(0, Math.min(canvasRect.height - 100, y));

    // The layer id is generated exactly once. It is used both for the dedup
    // bookkeeping and as the node's data-id.
    const layerId = `${nodeType}-${Date.now()}-${Math.floor(Math.random() * 10000)}`;

    if (recentlyCreated.nodeIds.has(layerId)) {
        return;
    }
    recentlyCreated.nodeIds.add(layerId);

    // Keep only the 10 most recent ids (Sets iterate in insertion order).
    if (recentlyCreated.nodeIds.size > 10) {
        const oldest = recentlyCreated.nodeIds.values().next().value;
        recentlyCreated.nodeIds.delete(oldest);
    }

    nodeCounter[nodeType] = (nodeCounter[nodeType] || 0) + 1;

    const canvasNode = document.createElement('div');
    canvasNode.className = `canvas-node ${nodeType}-node`;
    canvasNode.setAttribute('data-type', nodeType);
    canvasNode.setAttribute('data-id', layerId);
    canvasNode.style.position = 'absolute';
    canvasNode.style.left = `${posX}px`;
    canvasNode.style.top = `${posY}px`;

    const nodeConfig = window.neuralNetwork.createNodeConfig(nodeType);

    // Display labels. Hidden/conv/pool names are numbered by how many nodes of
    // that type are already on the canvas.
    let nodeName, inputShape, outputShape, parameters;

    switch (nodeType) {
        case 'input':
            nodeName = 'Input Layer';
            inputShape = 'N/A';
            outputShape = '[' + nodeConfig.shape.join(' × ') + ']';
            parameters = nodeConfig.parameters;
            break;
        case 'hidden': {
            const hiddenCount = document.querySelectorAll('.canvas-node[data-type="hidden"]').length;
            nodeConfig.units = hiddenCount === 0 ? 128 : 64;
            nodeName = `Hidden Layer ${hiddenCount + 1}`;
            inputShape = 'Connect input';
            outputShape = `[${nodeConfig.units}]`;
            parameters = 'Connect input to calculate';
            break;
        }
        case 'output':
            nodeName = 'Output Layer';
            inputShape = 'Connect input';
            outputShape = `[${nodeConfig.units}]`;
            parameters = 'Connect input to calculate';
            break;
        case 'conv': {
            const convCount = document.querySelectorAll('.canvas-node[data-type="conv"]').length;
            nodeConfig.filters = 32 * (convCount + 1);
            nodeName = `Conv2D ${convCount + 1}`;
            inputShape = 'Connect input';
            outputShape = 'Depends on input';
            parameters = `Kernel: ${nodeConfig.kernelSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
            break;
        }
        case 'pool': {
            const poolCount = document.querySelectorAll('.canvas-node[data-type="pool"]').length;
            nodeName = `Pooling ${poolCount + 1}`;
            inputShape = 'Connect input';
            outputShape = 'Depends on input';
            parameters = `Pool size: ${nodeConfig.poolSize.join('×')}\nStride: ${nodeConfig.strides.join('×')}\nPadding: ${nodeConfig.padding}`;
            break;
        }
        default:
            nodeName = 'Unknown Layer';
            inputShape = 'N/A';
            outputShape = 'N/A';
            parameters = 'N/A';
    }

    const nodeContent = document.createElement('div');
    nodeContent.className = 'node-content';

    const shapeInfo = document.createElement('div');
    shapeInfo.className = 'shape-info';
    shapeInfo.innerHTML = `
<div class="shape-row"><span class="shape-label">Input:</span> <span class="input-shape">${inputShape}</span></div>
<div class="shape-row"><span class="shape-label">Output:</span> <span class="output-shape">${outputShape}</span></div>
`;

    const paramsSection = document.createElement('div');
    paramsSection.className = 'params-section';
    paramsSection.innerHTML = `
<div class="params-details">${parameters}</div>
<div class="node-parameters">Params: ${nodeConfig.parameters !== undefined ? formatNumber(nodeConfig.parameters) : '?'}</div>
`;

    nodeContent.appendChild(shapeInfo);
    nodeContent.appendChild(paramsSection);

    const dimensionsSection = document.createElement('div');
    dimensionsSection.className = 'node-dimensions';

    // Compact dimension summary shown under the title.
    let dimensionsText = '';
    switch (nodeType) {
        case 'input':
            dimensionsText = nodeConfig.shape.join(' × ');
            break;
        case 'hidden':
        case 'output':
            dimensionsText = nodeConfig.units.toString();
            break;
        case 'conv':
            if (nodeConfig.inputShape && nodeConfig.outputShape) {
                dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
            } else {
                dimensionsText = `? → ${nodeConfig.filters} filters`;
            }
            break;
        case 'pool':
            if (nodeConfig.inputShape && nodeConfig.outputShape) {
                dimensionsText = `${nodeConfig.inputShape.join('×')} → ${nodeConfig.outputShape.join('×')}`;
            } else {
                dimensionsText = `? → ?`;
            }
            break;
        case 'linear':
            dimensionsText = `${nodeConfig.inputFeatures} → ${nodeConfig.outputFeatures}`;
            break;
    }
    dimensionsSection.textContent = dimensionsText;

    const nodeTitle = document.createElement('div');
    nodeTitle.className = 'node-title';
    nodeTitle.textContent = nodeName;

    const nodeControls = document.createElement('div');
    nodeControls.className = 'node-controls';

    const editButton = document.createElement('button');
    editButton.className = 'node-edit-btn';
    editButton.innerHTML = '✎';
    editButton.title = 'Edit Layer';

    const deleteButton = document.createElement('button');
    deleteButton.className = 'node-delete-btn';
    deleteButton.innerHTML = '×';
    deleteButton.title = 'Delete Layer';

    nodeControls.appendChild(editButton);
    nodeControls.appendChild(deleteButton);

    const portIn = document.createElement('div');
    portIn.className = 'node-port port-in';

    const portOut = document.createElement('div');
    portOut.className = 'node-port port-out';

    canvasNode.appendChild(nodeTitle);
    canvasNode.appendChild(nodeControls);
    canvasNode.appendChild(dimensionsSection);
    canvasNode.appendChild(nodeContent);
    canvasNode.appendChild(portIn);
    canvasNode.appendChild(portOut);

    canvasNode.setAttribute('data-name', nodeName);
    canvasNode.setAttribute('data-dimensions', dimensionsText);

    canvas.appendChild(canvasNode);

    // The element and the model record share this config object.
    canvasNode.layerConfig = nodeConfig;

    // Interactions: drag to move, drag from the output port to connect,
    // double-click / ✎ to edit, right-click / × to delete.
    canvasNode.addEventListener('mousedown', startDrag);

    portIn.addEventListener('mousedown', (ev) => {
        ev.stopPropagation();
    });

    portOut.addEventListener('mousedown', (ev) => {
        ev.stopPropagation();
        startConnection(canvasNode, ev);
    });

    canvasNode.addEventListener('dblclick', () => {
        openLayerEditor(canvasNode);
    });

    canvasNode.addEventListener('contextmenu', (ev) => {
        ev.preventDefault();
        deleteNode(canvasNode);
    });

    editButton.addEventListener('click', (ev) => {
        ev.stopPropagation();
        openLayerEditor(canvasNode);
    });

    deleteButton.addEventListener('click', (ev) => {
        ev.stopPropagation();
        deleteNode(canvasNode);
    });

    networkLayers.layers.push({
        id: layerId,
        type: nodeType,
        name: nodeName,
        position: { x: posX, y: posY },
        dimensions: dimensionsText,
        config: nodeConfig,
        parameters: nodeConfig.parameters || 0
    });

    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));

    updateConnections();

    const canvasHint = document.querySelector('.canvas-hint');
    if (canvasHint) {
        canvasHint.style.display = 'none';
    }

    // The palette drag is finished.
    draggedNode = null;
    recentlyCreated.inProgress = false;

    // Re-arm drop handling shortly afterwards for the next gesture.
    setTimeout(() => {
        if (window.draggedNode) {
            delete window.draggedNode;
        }
        recentlyCreated.isDropHandled = false;
    }, 100);
});
|
|
|
|
|
/**
 * Deletes a canvas node and everything that references it.
 *
 * It removes:
 * - the node's connection elements from the DOM;
 * - its edges from `networkLayers.connections`;
 * - its record from `networkLayers.layers`;
 * - its id from every other layer's `connections` adjacency list.
 *
 * It then detaches the node, shows the canvas hint again if no nodes remain,
 * re-lays out connections, and dispatches `networkUpdated`.
 *
 * @param {HTMLElement|null} node - The `.canvas-node` element to delete; ignored if falsy.
 */
function deleteNode(node) {
    if (!node) return;

    const nodeId = node.getAttribute('data-id');

    // Remove the rendered lines attached to this node.
    const connections = document.querySelectorAll(`.connection[data-source="${nodeId}"], .connection[data-target="${nodeId}"]`);
    connections.forEach(connection => {
        if (connection.parentNode) {
            connection.parentNode.removeChild(connection);
        }
    });

    // Drop the node's edges from the model.
    networkLayers.connections = networkLayers.connections.filter(conn =>
        conn.source !== nodeId && conn.target !== nodeId
    );

    // Remove the layer record in place, so the layers array identity is kept.
    const layerIndex = networkLayers.layers.findIndex(layer => layer.id === nodeId);
    if (layerIndex !== -1) {
        networkLayers.layers.splice(layerIndex, 1);
    }

    // endConnection records each edge on both endpoints' `connections` lists.
    // Scrub the deleted id so surviving layers don't point at a node that is gone.
    networkLayers.layers.forEach(layer => {
        if (Array.isArray(layer.connections)) {
            layer.connections = layer.connections.filter(id => id !== nodeId);
        }
    });

    if (node.parentNode) {
        node.parentNode.removeChild(node);
    }

    // An empty canvas shows the hint again.
    if (document.querySelectorAll('.canvas-node').length === 0) {
        const canvasHint = document.querySelector('.canvas-hint');
        if (canvasHint) {
            canvasHint.style.display = 'block';
        }
    }

    updateConnections();

    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));
}
|
|
|
|
|
/**
 * `mousedown` handler on a canvas node that begins repositioning it.
 *
 * Ignored when:
 * - a connection is being drawn;
 * - a non-primary mouse button was pressed (right-click is used by the context-menu
 *   delete handler);
 * - the press landed on the node's controls or ports.
 *
 * Otherwise it:
 * - records the pointer offset inside the node;
 * - marks the node and `<body>` with dragging classes;
 * - registers `dragNode` / `stopDrag` on the document.
 *
 * @param {MouseEvent} e - The mousedown event.
 */
function startDrag(e) {
    console.log('[DEBUG] startDrag called', e.target);

    if (isConnecting) return;

    // Only the primary button moves nodes. Synthetic events without a numeric
    // `button` are still accepted.
    if (typeof e.button === 'number' && e.button !== 0) return;

    // Buttons and ports have their own handlers.
    if (e.target.closest('.node-controls') || e.target.closest('.node-port')) {
        return;
    }

    const target = e.target.closest('.canvas-node');
    if (!target) {
        console.error('[ERROR] No canvas-node found in startDrag');
        return;
    }

    // Enter the dragging state only once the node is known, so a failed lookup
    // cannot leave isDragging stuck at true.
    isDragging = true;

    const rect = target.getBoundingClientRect();

    // Pointer position relative to the node's top-left corner, kept constant while dragging.
    offsetX = e.clientX - rect.left;
    offsetY = e.clientY - rect.top;

    document.addEventListener('mousemove', dragNode);
    document.addEventListener('mouseup', stopDrag);

    draggedNode = target;

    // Raise the node above its siblings while it moves.
    draggedNode.style.zIndex = "100";
    draggedNode.classList.add('dragging');
    document.body.classList.add('node-dragging');

    // Suppress text selection / native drag behaviour.
    e.preventDefault();

    console.log(`[DEBUG] Started dragging node: ${target.getAttribute('data-id')}`);
}
|
|
|
|
|
/**
 * `mousemove` handler while a canvas node is being repositioned.
 *
 * The node's top-left corner follows the pointer minus the grab offset and is
 * clamped to the canvas bounds. The new position is mirrored into the matching
 * `networkLayers` record, and connection lines are redrawn.
 *
 * @param {MouseEvent} e - The mousemove event.
 */
function dragNode(e) {
    const active = isDragging && draggedNode;
    if (!active) {
        console.log('[WARN] dragNode called but not in dragging state');
        return;
    }

    const bounds = canvas.getBoundingClientRect();
    const width = draggedNode.offsetWidth || 150;
    const height = draggedNode.offsetHeight || 100;

    // Keep the whole node inside the canvas.
    const clamp = (value, max) => Math.max(0, Math.min(max, value));
    const x = clamp(e.clientX - bounds.left - offsetX, bounds.width - width);
    const y = clamp(e.clientY - bounds.top - offsetY, bounds.height - height);

    Object.assign(draggedNode.style, {
        position: 'absolute',
        left: `${x}px`,
        top: `${y}px`,
        width: `${width}px`
    });

    const id = draggedNode.getAttribute('data-id');
    const record = networkLayers.layers.find(layer => layer.id === id);
    if (record) {
        record.position = { x, y };
    }

    updateConnections();
}
|
|
|
|
|
/**
 * `mouseup` handler that ends node repositioning.
 *
 * Unregisters the document listeners added by `startDrag`, clears the dragging
 * state and classes, lowers the node back to z-index 10, and redraws connections.
 * Does nothing when no drag is active.
 *
 * @param {MouseEvent} e - The mouseup event (unused).
 */
function stopDrag(e) {
    if (!isDragging) return;

    console.log('[DEBUG] stopDrag called');

    document.removeEventListener('mousemove', dragNode);
    document.removeEventListener('mouseup', stopDrag);
    isDragging = false;
    document.body.classList.remove('node-dragging');

    const node = draggedNode;
    if (!node) return;

    node.style.zIndex = "10";
    node.classList.remove('dragging');
    updateConnections();

    console.log(`[DEBUG] Stopped dragging node: ${node.getAttribute('data-id')}`);
    draggedNode = null;
}
|
|
|
|
|
/**
 * Begins drawing a connection from `node`'s output port.
 *
 * Creates a zero-length temporary line anchored at the centre of the output port
 * (in canvas coordinates). It then marks the port active, highlights valid/invalid
 * targets, and registers `drawConnection` / `cancelConnection` on the document.
 *
 * @param {HTMLElement} node - Source canvas node.
 * @param {MouseEvent} e - The mousedown event on the output port.
 */
function startConnection(node, e) {
    isConnecting = true;
    startNode = node;

    const line = document.createElement('div');
    connectionLine = line;
    line.className = 'connection temp-connection';

    const outPort = node.querySelector('.node-port.port-out');
    const portBox = outPort.getBoundingClientRect();
    const origin = canvas.getBoundingClientRect();

    Object.assign(line.style, {
        left: `${portBox.left + portBox.width / 2 - origin.left}px`,
        top: `${portBox.top + portBox.height / 2 - origin.top}px`,
        width: '0px',
        transform: 'rotate(0deg)'
    });

    outPort.classList.add('active-port');
    highlightValidConnectionTargets(node);
    canvas.appendChild(line);

    document.addEventListener('mousemove', drawConnection);
    document.addEventListener('mouseup', cancelConnection);

    e.preventDefault();
}
|
|
|
|
|
/**
 * Marks every other canvas node's input port as a valid or invalid target for a
 * connection starting at `sourceNode`, using `isValidConnection`.
 *
 * @param {HTMLElement} sourceNode - Node the connection is being drawn from.
 */
function highlightValidConnectionTargets(sourceNode) {
    const sourceType = sourceNode.getAttribute('data-type');
    const sourceId = sourceNode.getAttribute('data-id');

    for (const candidate of document.querySelectorAll('.canvas-node')) {
        if (candidate === sourceNode) continue;

        const valid = isValidConnection(
            sourceType,
            candidate.getAttribute('data-type'),
            sourceId,
            candidate.getAttribute('data-id')
        );

        const inPort = candidate.querySelector('.node-port.port-in');
        if (!inPort) continue;

        inPort.classList.add(valid ? 'valid-target' : 'invalid-target');
    }
}
|
|
|
|
|
/**
 * Clears the connection-drawing highlight classes (`active-port`, `valid-target`,
 * `invalid-target`) from every input and output port.
 */
function removePortHighlights() {
    const ports = document.querySelectorAll('.node-port.port-in, .node-port.port-out');
    for (const port of ports) {
        port.classList.remove('active-port', 'valid-target', 'invalid-target');
    }
}
|
|
|
|
|
/**
 * Decides whether a connection from a `sourceType` layer to a `targetType` layer
 * may be created.
 *
 * Rejected when:
 * - the source is an output layer or the target is an input layer;
 * - source and target are the same layer;
 * - the opposite edge already exists (it would form a two-layer cycle);
 * - the same edge already exists (it would be duplicated in the DOM and model).
 *
 * Otherwise the per-type table below applies.
 *
 * @param {string} sourceType - data-type of the source node.
 * @param {string} targetType - data-type of the target node.
 * @param {string} sourceId - data-id of the source node.
 * @param {string} targetId - data-id of the target node.
 * @returns {boolean} true if the connection is allowed.
 */
function isValidConnection(sourceType, targetType, sourceId, targetId) {
    if (sourceType === 'output' || targetType === 'input') {
        return false;
    }

    // No self-loops.
    if (sourceId != null && sourceId === targetId) {
        return false;
    }

    // Reject both the reverse edge and a duplicate of an existing edge.
    const alreadyLinked = networkLayers.connections.some(conn =>
        (conn.source === targetId && conn.target === sourceId) ||
        (conn.source === sourceId && conn.target === targetId)
    );
    if (alreadyLinked) {
        return false;
    }

    switch (sourceType) {
        case 'input':
            return ['hidden', 'conv'].includes(targetType);
        case 'conv':
            return ['conv', 'pool', 'hidden'].includes(targetType);
        case 'pool':
            return ['conv', 'hidden'].includes(targetType);
        case 'hidden':
            return ['hidden', 'output'].includes(targetType);
        default:
            return false;
    }
}
|
|
|
|
|
/**
 * `mousemove` handler while a connection is being drawn.
 *
 * Stretches and rotates the temporary line from the source node's output port
 * to the pointer. It also toggles `port-hover` on whichever other node's input
 * port lies under the pointer.
 *
 * @param {MouseEvent} e - The mousemove event.
 */
function drawConnection(e) {
    if (!(isConnecting && connectionLine)) return;

    const bounds = canvas.getBoundingClientRect();
    const outBox = startNode.querySelector('.node-port.port-out').getBoundingClientRect();

    // Line origin: centre of the output port, in canvas coordinates.
    const fromX = outBox.left + outBox.width / 2 - bounds.left;
    const fromY = outBox.top + outBox.height / 2 - bounds.top;
    const dx = (e.clientX - bounds.left) - fromX;
    const dy = (e.clientY - bounds.top) - fromY;

    const length = Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2));
    const degrees = Math.atan2(dy, dx) * 180 / Math.PI;

    connectionLine.style.width = `${length}px`;
    connectionLine.style.transform = `rotate(${degrees}deg)`;

    for (const node of document.querySelectorAll('.canvas-node')) {
        if (node === startNode) continue;

        const inPort = node.querySelector('.node-port.port-in');
        if (!inPort) continue;

        const box = inPort.getBoundingClientRect();
        const hovering = e.clientX >= box.left && e.clientX <= box.right &&
            e.clientY >= box.top && e.clientY <= box.bottom;
        inPort.classList.toggle('port-hover', hovering);
    }
}
|
|
|
|
|
/**
 * `mouseup` handler that finishes connection drawing.
 *
 * If the pointer is over another node's input port and `isValidConnection` accepts
 * the pair, the connection is completed via `endConnection`. When several ports
 * match, the last one in document order wins. Otherwise the temporary line is
 * removed. In all cases the highlights, connection state and document listeners
 * are cleared.
 *
 * @param {MouseEvent} e - The mouseup event.
 */
function cancelConnection(e) {
    if (!isConnecting) return;

    let targetNode = null;
    for (const node of document.querySelectorAll('.canvas-node')) {
        if (node === startNode) continue;

        const inPort = node.querySelector('.node-port.port-in');
        if (!inPort) continue;

        const box = inPort.getBoundingClientRect();
        const overPort = e.clientX >= box.left && e.clientX <= box.right &&
            e.clientY >= box.top && e.clientY <= box.bottom;
        if (!overPort) continue;

        const accepted = isValidConnection(
            startNode.getAttribute('data-type'),
            node.getAttribute('data-type'),
            startNode.getAttribute('data-id'),
            node.getAttribute('data-id')
        );
        if (accepted) {
            targetNode = node;
        }
    }

    if (targetNode) {
        endConnection(targetNode);
    } else if (connectionLine && connectionLine.parentNode) {
        // Released over nothing valid: discard the preview line.
        connectionLine.parentNode.removeChild(connectionLine);
    }

    removePortHighlights();
    for (const port of document.querySelectorAll('.node-port')) {
        port.classList.remove('port-hover');
    }

    isConnecting = false;
    startNode = null;
    connectionLine = null;

    document.removeEventListener('mousemove', drawConnection);
    document.removeEventListener('mouseup', cancelConnection);
}
|
|
|
|
|
/**
 * Completes a connection from `startNode` to `targetNode`.
 *
 * The temporary preview line is always removed. If the pair is valid:
 * - a permanent `.connection` element is created between the two ports;
 * - the edge is recorded in `networkLayers.connections` and on both layers'
 *   `connections` lists;
 * - the target's input shape and parameters are recomputed and propagated
 *   downstream;
 * - `networkUpdated` is dispatched.
 *
 * Finally the connection-drawing state and document listeners are reset.
 *
 * @param {HTMLElement} targetNode - Node whose input port received the connection.
 */
function endConnection(targetNode) {
    if (!isConnecting || !connectionLine || !startNode) return;

    const sourceType = startNode.getAttribute('data-type');
    const targetType = targetNode.getAttribute('data-type');
    const sourceId = startNode.getAttribute('data-id');
    const targetId = targetNode.getAttribute('data-id');

    // Discard the preview line whether or not the connection is accepted. For a
    // rejected target it would otherwise stay on the canvas after connectionLine
    // is reset below.
    if (connectionLine.parentNode) {
        connectionLine.parentNode.removeChild(connectionLine);
    }

    if (isValidConnection(sourceType, targetType, sourceId, targetId)) {
        const connection = document.createElement('div');
        connection.className = 'connection';
        connection.setAttribute('data-source', sourceId);
        connection.setAttribute('data-target', targetId);

        canvas.appendChild(connection);

        // Initial geometry: from the centre of the output port to the centre of
        // the input port.
        const sourcePort = startNode.querySelector('.node-port.port-out');
        const targetPort = targetNode.querySelector('.node-port.port-in');

        if (sourcePort && targetPort) {
            const sourceRect = sourcePort.getBoundingClientRect();
            const targetRect = targetPort.getBoundingClientRect();
            const canvasRect = canvas.getBoundingClientRect();

            const startX = sourceRect.left + sourceRect.width / 2 - canvasRect.left;
            const startY = sourceRect.top + sourceRect.height / 2 - canvasRect.top;
            const endX = targetRect.left + targetRect.width / 2 - canvasRect.left;
            const endY = targetRect.top + targetRect.height / 2 - canvasRect.top;

            const length = Math.sqrt(Math.pow(endX - startX, 2) + Math.pow(endY - startY, 2));
            const angle = Math.atan2(endY - startY, endX - startX) * 180 / Math.PI;

            connection.style.left = `${startX}px`;
            connection.style.top = `${startY}px`;
            connection.style.width = `${length}px`;
            connection.style.transform = `rotate(${angle}deg)`;
        }

        const sourceLayerIndex = networkLayers.layers.findIndex(layer => layer.id === sourceId);
        const targetLayerIndex = networkLayers.layers.findIndex(layer => layer.id === targetId);

        if (sourceLayerIndex !== -1 && targetLayerIndex !== -1) {
            networkLayers.connections.push({
                source: sourceId,
                target: targetId
            });

            // Each layer also keeps a list of the ids it is linked to.
            if (!networkLayers.layers[sourceLayerIndex].connections) {
                networkLayers.layers[sourceLayerIndex].connections = [];
            }
            if (!networkLayers.layers[targetLayerIndex].connections) {
                networkLayers.layers[targetLayerIndex].connections = [];
            }

            networkLayers.layers[sourceLayerIndex].connections.push(targetId);
            networkLayers.layers[targetLayerIndex].connections.push(sourceId);

            const sourceConfig = networkLayers.layers[sourceLayerIndex].config;

            // Shape propagation is only possible once the source knows its output shape.
            if (sourceConfig && sourceConfig.outputShape) {
                if (!targetNode.layerConfig) {
                    targetNode.layerConfig = {};
                }

                targetNode.layerConfig.inputShape = [...sourceConfig.outputShape];

                updateNodeParameters(targetNode, targetType, sourceConfig);
                updateDownstreamNodes(targetId);
                forceUpdateNetworkParameters();
            }
        }

        document.dispatchEvent(new CustomEvent('networkUpdated', {
            detail: networkLayers
        }));
    }

    isConnecting = false;
    startNode = null;
    connectionLine = null;

    document.removeEventListener('mousemove', drawConnection);
    document.removeEventListener('mouseup', cancelConnection);
}
|
|
|
|
|
/**
 * Re-positions connection lines so each runs from its source node's output port
 * to its target node's input port. Lines whose source or target node no longer
 * exists are removed from the DOM and from `networkLayers.connections`.
 *
 * This runs on every mousemove while a node is dragged. All geometry is therefore
 * measured first and styles are written afterwards, so the browser does not have
 * to recompute layout between lines. The canvas rectangle is measured only once.
 *
 * @param {string|null} [specificNodeId=null] - Only update lines attached to this
 *   node id; when omitted, every non-temporary connection is updated.
 */
function updateConnections(specificNodeId = null) {
    const connections = specificNodeId
        ? document.querySelectorAll(`.connection[data-source="${specificNodeId}"], .connection[data-target="${specificNodeId}"]`)
        : document.querySelectorAll('.connection:not(.temp-connection)');

    // The canvas origin is the same for every line.
    const canvasRect = canvas.getBoundingClientRect();

    const placements = [];
    const orphans = [];

    // Phase 1: read-only pass (no DOM writes).
    connections.forEach(connection => {
        const sourceId = connection.getAttribute('data-source');
        const targetId = connection.getAttribute('data-target');

        const sourceNode = document.querySelector(`.canvas-node[data-id="${sourceId}"]`);
        const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);

        if (!sourceNode || !targetNode) {
            orphans.push({ connection, sourceId, targetId });
            return;
        }

        const sourcePort = sourceNode.querySelector('.node-port.port-out');
        const targetPort = targetNode.querySelector('.node-port.port-in');
        if (!sourcePort || !targetPort) return;

        const sourceRect = sourcePort.getBoundingClientRect();
        const targetRect = targetPort.getBoundingClientRect();

        const startX = sourceRect.left + sourceRect.width / 2 - canvasRect.left;
        const startY = sourceRect.top + sourceRect.height / 2 - canvasRect.top;
        const endX = targetRect.left + targetRect.width / 2 - canvasRect.left;
        const endY = targetRect.top + targetRect.height / 2 - canvasRect.top;

        placements.push({
            connection,
            startX,
            startY,
            length: Math.sqrt(Math.pow(endX - startX, 2) + Math.pow(endY - startY, 2)),
            angle: Math.atan2(endY - startY, endX - startX) * 180 / Math.PI
        });
    });

    // Phase 2: apply the computed geometry.
    placements.forEach(({ connection, startX, startY, length, angle }) => {
        connection.style.left = `${startX}px`;
        connection.style.top = `${startY}px`;
        connection.style.width = `${length}px`;
        connection.style.transform = `rotate(${angle}deg)`;
    });

    // Phase 3: drop lines whose endpoints are gone, from both the DOM and the model.
    orphans.forEach(({ connection, sourceId, targetId }) => {
        if (!connection.parentNode) return;

        console.log(`[DEBUG] Removing orphaned connection between ${sourceId} and ${targetId}`);
        connection.parentNode.removeChild(connection);

        const connIndex = networkLayers.connections.findIndex(conn =>
            conn.source === sourceId && conn.target === targetId
        );
        if (connIndex !== -1) {
            networkLayers.connections.splice(connIndex, 1);
        }
    });
}
|
|
|
|
|
/**
 * Recomputes a node's shapes and parameter count after it receives input from
 * `sourceConfig`. It then refreshes the node's labels and the matching record
 * in `networkLayers`.
 *
 * Shapes:
 * - inputShape is copied from `sourceConfig.outputShape`;
 * - hidden/output layers produce `[units]`;
 * - conv/pool shapes come from `window.neuralNetwork.calculateOutputShape` when
 *   that function exists.
 *
 * Parameters: `window.neuralNetwork.calculateParameters` is used when it exists;
 * otherwise a local estimate is computed.
 *
 * @param {HTMLElement} node - Canvas node; its `layerConfig` is created if missing.
 * @param {string} nodeType - Layer type ('hidden', 'output', 'conv', 'pool', ...).
 * @param {Object} sourceConfig - Config of the upstream layer feeding this node.
 */
function updateNodeParameters(node, nodeType, sourceConfig) {
    if (!node || !nodeType || !sourceConfig) return;

    const nodeId = node.getAttribute('data-id');

    if (!node.layerConfig) {
        node.layerConfig = {};
    }
    const config = node.layerConfig;

    if (sourceConfig.outputShape) {
        config.inputShape = [...sourceConfig.outputShape];

        switch (nodeType) {
            case 'hidden':
            case 'output':
                config.outputShape = [config.units];
                break;
            case 'conv':
            case 'pool':
                if (window.neuralNetwork && window.neuralNetwork.calculateOutputShape) {
                    config.outputShape = window.neuralNetwork.calculateOutputShape(
                        nodeType,
                        config.inputShape,
                        config
                    );
                }
                break;
        }
    }

    let newParams = 0;

    if (window.neuralNetwork && window.neuralNetwork.calculateParameters) {
        newParams = window.neuralNetwork.calculateParameters(
            nodeType,
            config,
            sourceConfig
        );
    } else {
        // Local fallback estimate.
        switch (nodeType) {
            case 'hidden':
            case 'output':
                if (config.inputShape && config.units) {
                    // The palette has no flatten layer, yet conv/pool may feed a hidden layer.
                    // Multi-dimensional inputs are therefore treated as flattened: fan-in is
                    // the product of all input dimensions (equal to inputShape[0] for 1-D inputs).
                    const fanIn = config.inputShape.reduce((acc, dim) => acc * dim, 1);
                    newParams = (fanIn * config.units) + config.units;
                }
                break;
            case 'conv':
                if (config.inputShape && config.filters && config.kernelSize) {
                    // Channels are taken from the third dimension when present
                    // (assumes an [H, W, C] layout), otherwise 1.
                    const inputChannels = config.inputShape.length > 2 ? config.inputShape[2] : 1;
                    newParams = (config.kernelSize[0] * config.kernelSize[1] *
                        inputChannels * config.filters) + config.filters;
                }
                break;
            case 'pool':
                // Pooling has no trainable weights.
                newParams = 0;
                break;
        }
    }

    if (newParams !== undefined) {
        config.parameters = newParams;

        const layerIndex = networkLayers.layers.findIndex(layer => layer.id === nodeId);
        if (layerIndex !== -1) {
            networkLayers.layers[layerIndex].parameters = newParams;
            if (networkLayers.layers[layerIndex].config) {
                networkLayers.layers[layerIndex].config.parameters = newParams;

                if (config.outputShape) {
                    networkLayers.layers[layerIndex].config.outputShape = [...config.outputShape];
                }
            }
        }

        const paramsDisplay = node.querySelector('.node-parameters');
        if (paramsDisplay) {
            paramsDisplay.textContent = `Params: ${formatNumber(newParams)}`;
        }
    }

    if (config.inputShape) {
        const inputShapeDisplay = node.querySelector('.input-shape');
        if (inputShapeDisplay) {
            inputShapeDisplay.textContent = `[${config.inputShape.join(' × ')}]`;
        }
    }

    if (config.outputShape) {
        const outputShapeDisplay = node.querySelector('.output-shape');
        if (outputShapeDisplay) {
            outputShapeDisplay.textContent = `[${config.outputShape.join(' × ')}]`;
        }
    }

    updateNodeDimensions(node);

    // Briefly hide and re-show the node on the next tick. Reading offsetHeight in
    // between forces a synchronous layout, presumably to make the browser repaint
    // the updated labels.
    setTimeout(() => {
        const originalDisplay = node.style.display;
        node.style.display = 'none';
        void node.offsetHeight;
        node.style.display = originalDisplay;
    }, 10);
}
|
|
|
|
|
/**
 * Rewrites a node's compact dimension summary (`.node-dimensions` and the
 * `data-dimensions` attribute) from its `layerConfig`.
 *
 * Leaves both untouched when no summary can be produced for the node's type.
 *
 * @param {HTMLElement} node - Canvas node carrying a `layerConfig`.
 */
function updateNodeDimensions(node) {
    if (!node || !node.layerConfig) return;

    const kind = node.getAttribute('data-type');
    const section = node.querySelector('.node-dimensions');
    if (!section) return;

    const cfg = node.layerConfig;
    const hasShapes = Boolean(cfg.inputShape && cfg.outputShape);
    const transition = () => `${cfg.inputShape.join('×')} → ${cfg.outputShape.join('×')}`;

    let text = '';
    if (kind === 'input') {
        if (cfg.shape) {
            text = cfg.shape.join(' × ');
        }
    } else if (kind === 'hidden' || kind === 'output') {
        text = cfg.units ? cfg.units.toString() : '?';
    } else if (kind === 'conv') {
        if (hasShapes) {
            text = transition();
        } else if (cfg.filters) {
            text = `? → ${cfg.filters} filters`;
        }
    } else if (kind === 'pool') {
        text = hasShapes ? transition() : `? → ?`;
    } else if (kind === 'linear') {
        if (cfg.inputFeatures && cfg.outputFeatures) {
            text = `${cfg.inputFeatures} → ${cfg.outputFeatures}`;
        }
    }

    if (!text) return;

    section.textContent = text;
    node.setAttribute('data-dimensions', text);
}
|
|
|
|
|
/**
 * Propagates shape and parameter changes from `nodeId` to every layer reachable
 * through outgoing connections, depth-first.
 *
 * For each edge where the source's config has an `outputShape`:
 * - the target's input shape is copied from it (on both the element and the
 *   model record);
 * - `updateNodeParameters` is called on the target;
 * - the walk continues from the target.
 *
 * @param {string} nodeId - Layer whose outputs changed.
 * @param {Set<string>} [ancestors] - Internal: ids on the current recursion path,
 *   used to stop at cycles. Callers should omit it.
 */
function updateDownstreamNodes(nodeId, ancestors = new Set()) {
    // isValidConnection does not reject cycles longer than two layers (e.g.
    // hidden→hidden→hidden→first). Without this guard such a loop recursed until
    // the call stack overflowed.
    if (ancestors.has(nodeId)) return;
    ancestors.add(nodeId);

    try {
        const outgoingConnections = networkLayers.connections.filter(conn => conn.source === nodeId);

        outgoingConnections.forEach(conn => {
            const targetId = conn.target;
            const targetNode = document.querySelector(`.canvas-node[data-id="${targetId}"]`);
            const sourceNode = document.querySelector(`.canvas-node[data-id="${nodeId}"]`);

            if (!targetNode || !sourceNode) return;

            const targetType = targetNode.getAttribute('data-type');
            const sourceType = sourceNode.getAttribute('data-type');
            if (!targetType || !sourceType) return;

            const sourceIndex = networkLayers.layers.findIndex(layer => layer.id === nodeId);
            const targetIndex = networkLayers.layers.findIndex(layer => layer.id === targetId);
            if (sourceIndex === -1 || targetIndex === -1) return;

            const sourceConfig = networkLayers.layers[sourceIndex].config;
            if (!sourceConfig || !sourceConfig.outputShape) return;

            if (!targetNode.layerConfig) {
                targetNode.layerConfig = {};
            }

            targetNode.layerConfig.inputShape = [...sourceConfig.outputShape];
            networkLayers.layers[targetIndex].config.inputShape = [...sourceConfig.outputShape];

            updateNodeParameters(targetNode, targetType, sourceConfig);

            updateDownstreamNodes(targetId, ancestors);
        });
    } finally {
        // Only the current path is tracked, so a node reachable by several paths
        // (diamond-shaped graph) is still updated once per incoming path.
        ancestors.delete(nodeId);
    }
}
|
|
|
|
|
/**
 * Re-propagates shapes and parameters through the whole network.
 *
 * Starts a downstream update from every root layer (one that is not the target
 * of any connection), then dispatches `networkUpdated`.
 */
function forceUpdateNetworkParameters() {
    const hasIncoming = new Set();
    for (const { target } of networkLayers.connections) {
        hasIncoming.add(target);
    }

    // Collect the roots first, then walk from each of them.
    const roots = [];
    for (const layer of networkLayers.layers) {
        if (!hasIncoming.has(layer.id)) {
            roots.push(layer.id);
        }
    }
    for (const rootId of roots) {
        updateDownstreamNodes(rootId);
    }

    document.dispatchEvent(new CustomEvent('networkUpdated', {
        detail: networkLayers
    }));
}
|
|
|
|
|
/**
 * Returns the network model maintained by this module.
 *
 * This is the live object, not a copy: later edits on the canvas mutate it,
 * and clearAllNodes replaces it with a fresh object.
 *
 * @returns {{layers: Array<Object>, connections: Array<{source: string, target: string}>}}
 */
function getNetworkArchitecture() {
    const model = networkLayers;
    return model;
}
|
|
|
|
|
/**
 * Removes every canvas node and connection line, replaces the network model
 * with an empty one, and resets the layer counter when that API is available.
 * Then shows the canvas hint and dispatches `networkUpdated`.
 */
function clearAllNodes() {
    document.querySelectorAll('.canvas-node, .connection').forEach(el => {
        el.parentNode.removeChild(el);
    });

    networkLayers = {
        layers: [],
        connections: []
    };

    // Other helpers here guard optional window.neuralNetwork APIs before calling
    // them. Do the same so clearing the canvas cannot throw when the API is absent.
    if (window.neuralNetwork && typeof window.neuralNetwork.resetLayerCounter === 'function') {
        window.neuralNetwork.resetLayerCounter();
    }

    const canvasHint = document.querySelector('.canvas-hint');
    if (canvasHint) {
        canvasHint.style.display = 'block';
    }

    const event = new CustomEvent('networkUpdated', { detail: networkLayers });
    document.dispatchEvent(event);
}
|
|
|
|
|
/**
 * Asks the layer-editor UI to open for `node` by dispatching an `openLayerEditor`
 * event on the document. The event detail carries the node's id, type, name,
 * dimensions summary and the element itself.
 *
 * @param {HTMLElement|null} node - Canvas node to edit; ignored if falsy.
 */
function openLayerEditor(node) {
    if (!node) return;

    const [id, type, name, dimensions] = ['data-id', 'data-type', 'data-name', 'data-dimensions']
        .map(attribute => node.getAttribute(attribute));

    document.dispatchEvent(new CustomEvent('openLayerEditor', {
        detail: { id, type, name, dimensions, node }
    }));
}
|
|
|
|
|
/**
 * Creates a full-size SVG overlay inside the canvas, appends it and returns it.
 *
 * The overlay is absolutely positioned at the canvas origin, ignores pointer
 * events and uses z-index 5.
 *
 * @returns {SVGSVGElement} The appended overlay element.
 */
function createSVGContainer() {
    const SVG_NS = 'http://www.w3.org/2000/svg';
    const overlay = document.createElementNS(SVG_NS, 'svg');
    overlay.classList.add('svg-container');

    Object.assign(overlay.style, {
        position: 'absolute',
        top: '0',
        left: '0',
        width: '100%',
        height: '100%',
        pointerEvents: 'none',
        zIndex: '5'
    });

    canvas.appendChild(overlay);
    return overlay;
}
|
|
|
|
|
// Public API used by other scripts on the page.
window.dragDrop = {
    getNetworkArchitecture,
    clearAllNodes,
    updateConnections
};

// Individual handlers exposed as globals for code that references them directly.
Object.assign(window, {
    startDrag,
    dragNode,
    stopDrag,
    deleteNode
});
|
} |