Mattral commited on
Commit
1219288
·
verified ·
1 Parent(s): c6b6224

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import cv2
4
+ import plotly.graph_objects as go
5
+ from plotly.subplots import make_subplots
6
+ import pandas as pd
7
+
8
+ # FFT processing functions
9
+ def apply_fft(image):
10
+ """Apply FFT to each channel of the image and return shifted FFT channels."""
11
+ fft_channels = []
12
+ for channel in cv2.split(image):
13
+ fft = np.fft.fft2(channel)
14
+ fft_shifted = np.fft.fftshift(fft)
15
+ fft_channels.append(fft_shifted)
16
+ return fft_channels
17
+
18
+ def filter_fft_percentage(fft_channels, percentage):
19
+ """Filter FFT channels to keep top percentage of magnitudes."""
20
+ filtered_fft = []
21
+ for fft_data in fft_channels:
22
+ magnitude = np.abs(fft_data)
23
+ sorted_mag = np.sort(magnitude.flatten())[::-1]
24
+ num_keep = int(len(sorted_mag) * percentage / 100)
25
+ threshold = sorted_mag[num_keep - 1] if num_keep > 0 else 0
26
+ mask = magnitude >= threshold
27
+ filtered_fft.append(fft_data * mask)
28
+ return filtered_fft
29
+
30
+ def inverse_fft(filtered_fft):
31
+ """Reconstruct image from filtered FFT channels."""
32
+ reconstructed_channels = []
33
+ for fft_data in filtered_fft:
34
+ fft_ishift = np.fft.ifftshift(fft_data)
35
+ img_reconstructed = np.fft.ifft2(fft_ishift).real
36
+ img_normalized = cv2.normalize(img_reconstructed, None, 0, 255, cv2.NORM_MINMAX)
37
+ reconstructed_channels.append(img_normalized.astype(np.uint8))
38
+ return cv2.merge(reconstructed_channels)
39
+
40
+ def create_3d_plot(fft_channels, downsample_factor=1):
41
+ """Create interactive 3D surface plots using Plotly."""
42
+ fig = make_subplots(
43
+ rows=3, cols=2,
44
+ specs=[[{'type': 'scene'}, {'type': 'scene'}],
45
+ [{'type': 'scene'}, {'type': 'scene'}],
46
+ [{'type': 'scene'}, {'type': 'scene'}]],
47
+ subplot_titles=(
48
+ 'Blue - Magnitude', 'Blue - Phase',
49
+ 'Green - Magnitude', 'Green - Phase',
50
+ 'Red - Magnitude', 'Red - Phase'
51
+ )
52
+ )
53
+
54
+ channel_names = ['Blue', 'Green', 'Red']
55
+
56
+ for i, fft_data in enumerate(fft_channels):
57
+ # Downsample data for better performance
58
+ fft_down = fft_data[::downsample_factor, ::downsample_factor]
59
+ magnitude = np.abs(fft_down)
60
+ phase = np.angle(fft_down)
61
+
62
+ # Create grid coordinates
63
+ rows, cols = magnitude.shape
64
+ x = np.linspace(-cols//2, cols//2, cols)
65
+ y = np.linspace(-rows//2, rows//2, rows)
66
+ X, Y = np.meshgrid(x, y)
67
+
68
+ # Magnitude plot
69
+ fig.add_trace(
70
+ go.Surface(x=X, y=Y, z=magnitude, colorscale='Viridis', showscale=False),
71
+ row=i+1, col=1
72
+ )
73
+
74
+ # Phase plot
75
+ fig.add_trace(
76
+ go.Surface(x=X, y=Y, z=phase, colorscale='Inferno', showscale=False),
77
+ row=i+1, col=2
78
+ )
79
+
80
+ # Update layout for better visualization
81
+ fig.update_layout(
82
+ height=1500,
83
+ width=1200,
84
+ margin=dict(l=0, r=0, b=0, t=30),
85
+ scene_camera=dict(eye=dict(x=1.5, y=1.5, z=0.5)),
86
+ scene=dict(
87
+ xaxis=dict(title='Frequency X'),
88
+ yaxis=dict(title='Frequency Y'),
89
+ zaxis=dict(title='Magnitude/Phase')
90
+ )
91
+ )
92
+ return fig
93
+
94
+ # Streamlit UI
95
+ st.set_page_config(layout="wide")
96
+ st.title("Interactive Frequency Domain Analysis")
97
+
98
+ # Introduction Text
99
+ st.subheader("Introduction to FFT and Image Filtering")
100
+ st.write(
101
+ """Fast Fourier Transform (FFT) is a technique to transform an image from the spatial domain to the frequency domain.
102
+ In this domain, each frequency represents a different aspect of the image's texture and details.
103
+ By filtering out certain frequencies, you can modify the image's appearance, enhancing or suppressing certain features."""
104
+ )
105
+
106
+ uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
107
+
108
+ if uploaded_file is not None:
109
+ # Read and display original image
110
+ file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
111
+ image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
112
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113
+ st.image(image_rgb, caption="Original Image", use_column_width=True)
114
+
115
+ # Process FFT and store in session state
116
+ if 'fft_channels' not in st.session_state:
117
+ st.session_state.fft_channels = apply_fft(image)
118
+
119
+ # Create a form to submit frequency percentage selection
120
+ with st.form(key='fft_form'):
121
+ percentage = st.slider(
122
+ "Percentage of frequencies to retain:",
123
+ min_value=0.1, max_value=100.0, value=10.0, step=0.1,
124
+ help="Adjust the slider to select what portion of frequency components to keep. Lower values blur the image."
125
+ )
126
+ submit_button = st.form_submit_button(label="Apply Filter")
127
+
128
+ if submit_button:
129
+ # Apply filtering and reconstruct image
130
+ filtered_fft = filter_fft_percentage(st.session_state.fft_channels, percentage)
131
+ reconstructed = inverse_fft(filtered_fft)
132
+ reconstructed_rgb = cv2.cvtColor(reconstructed, cv2.COLOR_BGR2RGB)
133
+ st.image(reconstructed_rgb, caption="Reconstructed Image", use_column_width=True)
134
+
135
+ # Display FFT Data in Table Format
136
+ st.subheader("Frequency Data of Each Channel")
137
+ fft_data_dict = {}
138
+ for i, channel_name in enumerate(['Blue', 'Green', 'Red']):
139
+ magnitude = np.abs(st.session_state.fft_channels[i])
140
+ phase = np.angle(st.session_state.fft_channels[i])
141
+ fft_data_dict[channel_name] = {'Magnitude': magnitude, 'Phase': phase}
142
+
143
+ # Create DataFrames for each channel's FFT data
144
+ for channel_name, data in fft_data_dict.items():
145
+ st.write(f"### {channel_name} Channel FFT Data")
146
+ magnitude_df = pd.DataFrame(data['Magnitude'])
147
+ phase_df = pd.DataFrame(data['Phase'])
148
+ st.write("#### Magnitude Data:")
149
+ st.dataframe(magnitude_df.head(10)) # Display first 10 rows for brevity
150
+ st.write("#### Phase Data:")
151
+ st.dataframe(phase_df.head(10)) # Display first 10 rows for brevity
152
+
153
+ # Download button for reconstructed image
154
+ _, encoded_img = cv2.imencode('.png', reconstructed)
155
+ st.download_button(
156
+ "Download Reconstructed Image",
157
+ encoded_img.tobytes(),
158
+ "reconstructed.png",
159
+ "image/png"
160
+ )
161
+
162
+ # 3D visualization controls
163
+ st.subheader("3D Frequency Components Visualization")
164
+ downsample = st.slider(
165
+ "Downsampling factor for 3D plots:",
166
+ min_value=1, max_value=20, value=5,
167
+ help="Controls the resolution of the 3D surface plots. Higher values improve performance but reduce the plot's detail."
168
+ )
169
+
170
+ # Generate and display 3D plots
171
+ fig = create_3d_plot(filtered_fft, downsample)
172
+ st.plotly_chart(fig, use_container_width=True)