beta3 commited on
Commit
feaf2ab
·
verified ·
1 Parent(s): aa403b8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +416 -0
app.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from plotly.subplots import make_subplots
6
+
7
+ # Load data
8
+ def load_data():
9
+ """Load the dataset from a local CSV file"""
10
+ df = pd.read_csv("EEG_Eye_State.csv")
11
+ return df
12
+
13
+ # Initialize data
14
+ df = load_data()
15
+
16
+ # List of EEG channels
17
+ eeg_channels = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
18
+ 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
19
+
20
+ def plot_eeg_signals(start_time, duration, eye_state_filter, selected_channels):
21
+ """
22
+ Visualize the selected EEG signals
23
+ """
24
+ # Calculate indices based on time (128 Hz)
25
+ sampling_rate = 128
26
+ start_idx = int(start_time * sampling_rate)
27
+ end_idx = start_idx + int(duration * sampling_rate)
28
+
29
+ # Filter data segment
30
+ df_segment = df.iloc[start_idx:end_idx].copy()
31
+
32
+ # Filter by eye state if selected
33
+ if eye_state_filter != "Both":
34
+ filter_value = 1 if eye_state_filter == "Closed" else 0
35
+ df_segment = df_segment[df_segment['eyeDetection'] == filter_value]
36
+
37
+ if len(df_segment) == 0:
38
+ return None
39
+
40
+ # Create subplots
41
+ n_channels = len(selected_channels)
42
+ fig = make_subplots(
43
+ rows=n_channels,
44
+ cols=1,
45
+ shared_xaxes=True,
46
+ vertical_spacing=0.02,
47
+ subplot_titles=selected_channels
48
+ )
49
+
50
+ # Create time axis
51
+ time_axis = np.arange(len(df_segment)) / sampling_rate + start_time
52
+
53
+ # Add each channell
54
+ for idx, channel in enumerate(selected_channels, 1):
55
+ # Color based on eye state
56
+ colors = ['red' if x == 1 else 'blue' for x in df_segment['eyeDetection']]
57
+
58
+ fig.add_trace(
59
+ go.Scatter(
60
+ x=time_axis,
61
+ y=df_segment[channel],
62
+ mode='lines',
63
+ name=channel,
64
+ line=dict(color='steelblue', width=1),
65
+ showlegend=False
66
+ ),
67
+ row=idx, col=1
68
+ )
69
+
70
+ # Add shaded areas for closed eyes
71
+ eye_closed_mask = df_segment['eyeDetection'] == 1
72
+ if eye_closed_mask.any():
73
+ closed_indices = np.where(eye_closed_mask)[0]
74
+ # Group consecutive indices
75
+ if len(closed_indices) > 0:
76
+ groups = np.split(closed_indices, np.where(np.diff(closed_indices) != 1)[0] + 1)
77
+ for group in groups:
78
+ if len(group) > 0:
79
+ fig.add_vrect(
80
+ x0=time_axis[group[0]],
81
+ x1=time_axis[group[-1]],
82
+ fillcolor="red", opacity=0.1,
83
+ layer="below", line_width=0,
84
+ row=idx, col=1
85
+ )
86
+
87
+ # Update layout
88
+ fig.update_xaxes(title_text="Time (seconds)", row=n_channels, col=1)
89
+ fig.update_yaxes(title_text="Amplitude (μV)")
90
+
91
+ fig.update_layout(
92
+ height=200 * n_channels,
93
+ title_text=f"EEG Signals - {eye_state_filter} Eyes",
94
+ showlegend=False,
95
+ hovermode='x unified'
96
+ )
97
+
98
+ return fig
99
+
100
+ def plot_channel_comparison(channels, eye_state_filter, remove_outliers):
101
+ """
102
+ Compare specific channels between open and closed eyes
103
+ """
104
+ if not channels:
105
+ return None
106
+
107
+ n_channels = len(channels)
108
+
109
+ # Determine number of columns based on filter
110
+ n_cols = 2 if eye_state_filter == "Both" else 1
111
+
112
+ if eye_state_filter == "Both":
113
+ subplot_titles = [f'{ch} - Eyes Open' if i % 2 == 0 else f'{ch} - Eyes Closed'
114
+ for ch in channels for i in range(2)]
115
+ specs = [[{'type': 'box'}, {'type': 'histogram'}] for _ in range(n_channels)]
116
+ else:
117
+ state_label = "Eyes Open" if eye_state_filter == "Open" else "Eyes Closed"
118
+ subplot_titles = [f'{ch} - {state_label}' for ch in channels]
119
+ specs = [[{'type': 'box'}] for _ in range(n_channels)]
120
+
121
+ fig = make_subplots(
122
+ rows=n_channels, cols=n_cols,
123
+ subplot_titles=subplot_titles,
124
+ specs=specs,
125
+ vertical_spacing=0.08
126
+ )
127
+
128
+ for idx, channel in enumerate(channels, 1):
129
+ df_open = df[df['eyeDetection'] == 0][channel]
130
+ df_closed = df[df['eyeDetection'] == 1][channel]
131
+
132
+ # Filter outliers if requested
133
+ if remove_outliers:
134
+ def filter_outliers(data):
135
+ Q1 = data.quantile(0.25)
136
+ Q3 = data.quantile(0.75)
137
+ IQR = Q3 - Q1
138
+ lower_bound = Q1 - 1.5 * IQR
139
+ upper_bound = Q3 + 1.5 * IQR
140
+ return data[(data >= lower_bound) & (data <= upper_bound)]
141
+
142
+ df_open = filter_outliers(df_open)
143
+ df_closed = filter_outliers(df_closed)
144
+
145
+ if eye_state_filter in ["Both", "Open"]:
146
+ # Boxplot for Open
147
+ fig.add_trace(
148
+ go.Box(y=df_open, name=f'{channel} Open', marker_color='blue',
149
+ showlegend=(idx==1)),
150
+ row=idx, col=1
151
+ )
152
+
153
+ if eye_state_filter in ["Both", "Closed"]:
154
+ # Boxplot for Closed
155
+ fig.add_trace(
156
+ go.Box(y=df_closed, name=f'{channel} Closed', marker_color='red',
157
+ showlegend=(idx==1)),
158
+ row=idx, col=1
159
+ )
160
+
161
+ # Histogram only if "Both"
162
+ if eye_state_filter == "Both":
163
+ # Histograma Open
164
+ fig.add_trace(
165
+ go.Histogram(x=df_open, name=f'{channel} Open', marker_color='blue',
166
+ opacity=0.7, showlegend=False, nbinsx=30),
167
+ row=idx, col=2
168
+ )
169
+ # Histogram Closed
170
+ fig.add_trace(
171
+ go.Histogram(x=df_closed, name=f'{channel} Closed', marker_color='red',
172
+ opacity=0.7, showlegend=False, nbinsx=30),
173
+ row=idx, col=2
174
+ )
175
+
176
+ # Center and adjust histogram axes
177
+ all_data = pd.concat([df_open, df_closed])
178
+ data_min = all_data.min()
179
+ data_max = all_data.max()
180
+ data_range = data_max - data_min
181
+ margin = data_range * 0.1
182
+
183
+ fig.update_xaxes(
184
+ range=[data_min - margin, data_max + margin],
185
+ row=idx, col=2
186
+ )
187
+
188
+ fig.update_layout(
189
+ height=350 * n_channels,
190
+ title_text=f"Channel Distribution Comparison - {eye_state_filter} Eyes",
191
+ showlegend=True
192
+ )
193
+
194
+ if eye_state_filter == "Both":
195
+ fig.update_xaxes(title_text="Amplitude (μV)", row=n_channels, col=2)
196
+ fig.update_yaxes(title_text="Amplitude (μV)")
197
+
198
+ return fig
199
+
200
+ def get_statistics():
201
+ """
202
+ Generate dataset statistics in text format
203
+ """
204
+ stats = []
205
+
206
+ # General information
207
+ total_samples = len(df)
208
+ eyes_open = len(df[df['eyeDetection'] == 0])
209
+ eyes_closed = len(df[df['eyeDetection'] == 1])
210
+ duration = total_samples / 128 # seconds
211
+
212
+ stats.append(f"**Dataset Statistics**")
213
+ stats.append(f"- Total samples: {total_samples:,}")
214
+ stats.append(f"- Duration: {duration:.2f} seconds")
215
+ stats.append(f"- Sampling rate: 128 Hz")
216
+ stats.append(f"- Eyes Open samples: {eyes_open:,} ({eyes_open/total_samples*100:.1f}%)")
217
+ stats.append(f"- Eyes Closed samples: {eyes_closed:,} ({eyes_closed/total_samples*100:.1f}%)")
218
+
219
+ return "\n".join(stats)
220
+
221
+ def get_statistics_table():
222
+ """
223
+ Generate statistics table per channel
224
+ """
225
+ stats_data = []
226
+
227
+ for channel in eeg_channels:
228
+ channel_data = df[channel]
229
+ open_data = df[df['eyeDetection'] == 0][channel]
230
+ closed_data = df[df['eyeDetection'] == 1][channel]
231
+
232
+ stats_data.append({
233
+ 'Channel': channel,
234
+ 'Mean (All)': f"{channel_data.mean():.2f}",
235
+ 'Std (All)': f"{channel_data.std():.2f}",
236
+ 'Mean (Open)': f"{open_data.mean():.2f}",
237
+ 'Mean (Closed)': f"{closed_data.mean():.2f}",
238
+ 'Min': f"{channel_data.min():.2f}",
239
+ 'Max': f"{channel_data.max():.2f}"
240
+ })
241
+
242
+ return pd.DataFrame(stats_data)
243
+
244
+ def plot_correlation_matrix():
245
+ """
246
+ Visualize the correlation matrix between channels
247
+ """
248
+ corr_matrix = df[eeg_channels].corr()
249
+
250
+ fig = go.Figure(data=go.Heatmap(
251
+ z=corr_matrix.values,
252
+ x=eeg_channels,
253
+ y=eeg_channels,
254
+ colorscale='RdBu',
255
+ zmid=0,
256
+ text=corr_matrix.values,
257
+ texttemplate='%{text:.2f}',
258
+ textfont={"size": 9},
259
+ colorbar=dict(title="Correlation")
260
+ ))
261
+
262
+ fig.update_layout(
263
+ title={
264
+ 'text': "EEG Channels Correlation Matrix",
265
+ 'x': 0.5,
266
+ 'xanchor': 'center'
267
+ },
268
+ height=600,
269
+ width=1215,
270
+ xaxis={'side': 'bottom'}
271
+ )
272
+
273
+ return fig
274
+
275
+ # Create Gradio interface
276
+ demo = gr.Blocks(title="EEG Eye State Visualizer")
277
+
278
+ with demo:
279
+
280
+ gr.Markdown("""
281
+ # 🧠 EEG Eye State Visualizer
282
+
283
+ Explore and visualize the EEG Eye State Classification Dataset. This interactive tool allows you to:
284
+ - View EEG signals from 14 channels
285
+ - Compare patterns between open and closed eyes
286
+ - Analyze statistical distributions
287
+ - Examine channel correlations
288
+
289
+ **Dataset Info**: 14,980 samples | 128 Hz sampling rate | 14 EEG channels
290
+ """)
291
+
292
+ with gr.Tab("Signal Viewer"):
293
+ gr.Markdown("### Visualize EEG Signals")
294
+
295
+ with gr.Row():
296
+ with gr.Column(scale=1):
297
+ start_time = gr.Slider(
298
+ minimum=0,
299
+ maximum=117,
300
+ value=0,
301
+ step=0.5,
302
+ label="Start Time (seconds)"
303
+ )
304
+ duration = gr.Slider(
305
+ minimum=1,
306
+ maximum=10,
307
+ value=5,
308
+ step=0.5,
309
+ label="Duration (seconds)"
310
+ )
311
+ eye_state = gr.Radio(
312
+ choices=["Both", "Open", "Closed"],
313
+ value="Both",
314
+ label="Eye State Filter"
315
+ )
316
+ channels = gr.CheckboxGroup(
317
+ choices=eeg_channels,
318
+ value=['AF3', 'F7', 'O1', 'O2'],
319
+ label="Select Channels to Display"
320
+ )
321
+ plot_btn = gr.Button("Generate Plot", variant="primary")
322
+
323
+ with gr.Column(scale=3):
324
+ signal_plot = gr.Plot(label="EEG Signals")
325
+
326
+ plot_btn.click(
327
+ fn=plot_eeg_signals,
328
+ inputs=[start_time, duration, eye_state, channels],
329
+ outputs=signal_plot
330
+ )
331
+
332
+ with gr.Tab("Channel Analysis"):
333
+ gr.Markdown("### Compare Multiple Channels")
334
+
335
+ with gr.Row():
336
+ with gr.Column(scale=1):
337
+ channels_select = gr.CheckboxGroup(
338
+ choices=eeg_channels,
339
+ value=['AF3', 'O1'],
340
+ label="Select Channels to Compare"
341
+ )
342
+ eye_state_compare = gr.Radio(
343
+ choices=["Both", "Open", "Closed"],
344
+ value="Both",
345
+ label="Eye State Filter"
346
+ )
347
+ remove_outliers_check = gr.Checkbox(
348
+ label="Remove Outliers (IQR method)",
349
+ value=False
350
+ )
351
+ compare_btn = gr.Button("Analyze Channels", variant="primary")
352
+
353
+ with gr.Column(scale=3):
354
+ comparison_plot = gr.Plot(label="Channel Comparison")
355
+
356
+ compare_btn.click(
357
+ fn=plot_channel_comparison,
358
+ inputs=[channels_select, eye_state_compare, remove_outliers_check],
359
+ outputs=comparison_plot
360
+ )
361
+
362
+ with gr.Tab("Statistics"):
363
+ gr.Markdown("### Dataset Statistics")
364
+
365
+ stats_text = gr.Markdown(value=get_statistics())
366
+
367
+ gr.Markdown("### Channel Statistics Table (μV)")
368
+ stats_table = gr.Dataframe(
369
+ value=get_statistics_table(),
370
+ interactive=False,
371
+ wrap=True
372
+ )
373
+
374
+ gr.Markdown("### Correlation Matrix")
375
+ with gr.Row():
376
+ corr_plot = gr.Plot(
377
+ value=plot_correlation_matrix(),
378
+ container=True,
379
+ scale=1
380
+ )
381
+
382
+ with gr.Tab("About"):
383
+ gr.Markdown("""
384
+ ## About this Dataset
385
+
386
+ The EEG Eye State Classification Dataset contains continuous EEG measurements from 14 electrodes
387
+ collected during different eye states (open/closed).
388
+
389
+ ### Key Features:
390
+ - **Total Instances**: 14,980 observations
391
+ - **Features**: 14 EEG channel measurements
392
+ - **Sampling Rate**: 128 Hz
393
+ - **Duration**: ~117 seconds
394
+ - **Device**: Emotiv EEG Neuroheadset
395
+
396
+ ### Electrode Placement:
397
+ The 14 channels follow the international 10-20 system:
398
+ - Left hemisphere: AF3, F7, F3, FC5, T7, P7, O1
399
+ - Right hemisphere: O2, P8, T8, FC6, F4, F8, AF4
400
+
401
+ ### Citation:
402
+ ```
403
+ Rösler, O. (2013). EEG Eye State.
404
+ UCI Machine Learning Repository.
405
+ https://doi.org/10.24432/C57G7J
406
+ ```
407
+
408
+ ### Links:
409
+ - [Dataset on Hugging Face](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-eye-state-classification)
410
+ - [Original UCI Repository](https://archive.ics.uci.edu/dataset/264/eeg+eye+state)
411
+ - [Kaggle Example](https://www.kaggle.com/code/beta3logic/eye-state-eeg-classification-model-using-automl)
412
+ """)
413
+
414
+ # Launch application
415
+ if __name__ == "__main__":
416
+ demo.launch(ssr_mode=False)