"""Visualization helpers for notebook workflows.
Exports plotting utilities for hypnograms, spectra, topographic maps, and
the interactive ``scope`` viewer built on ``ipywidgets`` and Plotly.
"""
import pandas as pd
import numpy as np
from .segsrv import segsrv
[docs]
def default_xy():
    """Return 2-D scalp coordinates for the built-in 64-electrode EEG layout.

    The layout is a top-down projection with the nose pointing up (``Y`` is
    largest for frontal sites such as ``FPZ``). Odd-numbered labels sit at
    negative ``X``, even-numbered labels at positive ``X`` and midline
    (``...Z``) labels at ``X = 0``.

    Returns
    -------
    pandas.DataFrame
        64 rows with columns ``['CH', 'X', 'Y']``: the upper-case electrode
        label followed by its normalised Cartesian coordinates (floats).
    """
    # one (label, X, Y) record per electrode, in the historical row order
    layout = [
        ("FP1", -0.139058, 0.430423),
        ("AF7", -0.264503, 0.373607),
        ("AF3", -0.152969, 0.341595),
        ("F1", -0.091616, 0.251562),
        ("F3", -0.184692, 0.252734),
        ("F5", -0.276864, 0.263932),
        ("F7", -0.364058, 0.285114),
        ("FT7", -0.427975, 0.173607),
        ("FC5", -0.328783, 0.162185),
        ("FC3", -0.215938, 0.152059),
        ("FC1", -0.110678, 0.14838),
        ("C1", -0.1125, 0.05),
        ("C3", -0.225, 0.05),
        ("C5", -0.3375, 0.05),
        ("T7", -0.45, 0.05),
        ("TP7", -0.427975, -0.073607),
        ("CP5", -0.328783, -0.062185),
        ("CP3", -0.215938, -0.052059),
        ("CP1", -0.110678, -0.04838),
        ("P1", -0.091616, -0.151562),
        ("P3", -0.184692, -0.152734),
        ("P5", -0.276864, -0.163932),
        ("P7", -0.364058, -0.185114),
        ("P9", -0.4309, -0.271394),
        ("PO7", -0.264503, -0.273607),
        ("PO3", -0.152969, -0.241595),
        ("O1", -0.139058, -0.330422),
        ("IZ", 0.0, -0.45),
        ("OZ", 0.0, -0.35),
        ("POZ", 0.0, -0.25),
        ("PZ", 0.0, -0.15),
        ("CPZ", 0.0, -0.05),
        ("FPZ", 0.0, 0.45),
        ("FP2", 0.139058, 0.430423),
        ("AF8", 0.264503, 0.373607),
        ("AF4", 0.152969, 0.341595),
        ("AFZ", 0.0, 0.35),
        ("FZ", 0.0, 0.25),
        ("F2", 0.091616, 0.251562),
        ("F4", 0.184692, 0.252734),
        ("F6", 0.276864, 0.263932),
        ("F8", 0.364058, 0.285114),
        ("FT8", 0.427975, 0.173607),
        ("FC6", 0.328783, 0.162185),
        ("FC4", 0.215938, 0.152059),
        ("FC2", 0.110678, 0.14838),
        ("FCZ", 0.0, 0.15),
        ("CZ", 0.0, 0.05),
        ("C2", 0.1125, 0.05),
        ("C4", 0.225, 0.05),
        ("C6", 0.3375, 0.05),
        ("T8", 0.45, 0.05),
        ("TP8", 0.427975, -0.073607),
        ("CP6", 0.328783, -0.062185),
        ("CP4", 0.215938, -0.052059),
        ("CP2", 0.110678, -0.04838),
        ("P2", 0.091616, -0.151562),
        ("P4", 0.184692, -0.152734),
        ("P6", 0.276864, -0.163932),
        ("P8", 0.364058, -0.185114),
        ("P10", 0.4309, -0.271394),
        ("PO8", 0.264503, -0.273607),
        ("PO4", 0.152969, -0.241595),
        ("O2", 0.139058, -0.330422),
    ]
    frame = pd.DataFrame(layout, columns=["CH", "X", "Y"])
    # coordinates are always float64, even where every value is 0.0
    return frame.astype({"X": float, "Y": float})
# --------------------------------------------------------------------------------
[docs]
def stgcol(ss):
    """Translate sleep-stage labels into hex RGBA display colours.

    Parameters
    ----------
    ss : iterable of str
        Stage labels such as ``'W'``, ``'N1'``, ``'N2'``, ``'N3'``,
        ``'NREM4'``, ``'R'``, ``'L'`` or ``'?'``; ``None`` entries are also
        accepted.

    Returns
    -------
    list of str
        One ``'#RRGGBBAA'`` string per input label (e.g. ``'#0050C8FF'`` for
        N2). ``'?'`` maps to a semi-transparent grey and ``None`` to a fully
        transparent colour; any other unrecognised label is passed through
        unchanged.
    """
    palette = {
        'N1': "#00BEFAFF",
        'N2': "#0050C8FF",
        'N3': "#000050FF",
        'NREM4': "#000032FF",
        'R': "#FA1432FF",
        'W': "#31AD52FF",
        'L': "#F6F32AFF",
        '?': "#64646464",
        None: "#00000000",
    }
    # unknown labels fall back to themselves
    return list(map(lambda label: palette.get(label, label), ss))
# --------------------------------------------------------------------------------
[docs]
def stgn(ss):
    """Translate sleep-stage labels into the numeric codes used for plotting.

    Codes: N1 -> -1, N2 -> -2, N3 -> -3, NREM4 -> -4, R -> 0, W -> 1,
    and L / ``'?'`` / ``None`` -> 2.

    Parameters
    ----------
    ss : iterable of str
        Sleep stage labels (``None`` entries are accepted).

    Returns
    -------
    list
        Numeric stage code for each label. Unrecognised labels are passed
        through unchanged.
    """
    codes = {
        'N1': -1,
        'N2': -2,
        'N3': -3,
        'NREM4': -4,
        'R': 0,
        'W': 1,
        'L': 2,
        '?': 2,
        None: 2,
    }
    numeric = []
    for label in ss:
        numeric.append(codes.get(label, label))
    return numeric
# --------------------------------------------------------------------------------
#
# Visualizations
#
# --------------------------------------------------------------------------------
# --------------------------------------------------------------------------------
[docs]
def hypno( ss , e = None , xsize = 20 , ysize = 2 , title = None ):
    """Plot a colour-coded hypnogram from a sequence of sleep stage labels.

    Parameters
    ----------
    ss : array-like of str
        Per-epoch sleep stage labels (e.g. from
        :meth:`~lunapi.instance.inst.stages`). Nothing is drawn when empty.
    e : array-like of int, optional
        Epoch indices (list, numpy array or pandas Series). If omitted,
        epochs are numbered ``0, 1, 2, ...``. Values are divided by 120
        before plotting to convert to hours (assuming 30-second epochs).
    xsize : float, optional
        Figure width in inches. Default ``20``.
    ysize : float, optional
        Figure height in inches. Default ``2``.
    title : str, optional
        Optional plot title.

    Returns
    -------
    None
        The hypnogram is rendered inline via Matplotlib.
    """
    # nothing to draw; also avoids max() over an empty axis below
    if len(ss) == 0:
        return
    import matplotlib.pyplot as plt
    ssn = stgn(ss)
    if e is None:
        e = np.arange(0, len(ssn), 1)
    # asarray so plain lists work too (list / 120 would raise TypeError)
    hrs = np.asarray(e, dtype=float) / 120
    yticks = [-3, -2, -1, 0, 1, 2]
    ylabels = ['N3', 'N2', 'N1', 'R', 'W', '?']
    ymin = -3.5
    # stgn() maps NREM4 to -4, which would otherwise fall below the y-limits
    if -4 in ssn:
        yticks = [-4] + yticks
        ylabels = ['NREM4'] + ylabels
        ymin = -4.5
    plt.figure(figsize=(xsize, ysize))
    plt.plot(hrs, ssn, c='gray', lw=0.5)
    plt.scatter(hrs, ssn, c=stgcol(ss), zorder=2.5, s=10)
    plt.ylabel('Sleep stage')
    plt.xlabel('Time (hrs)')
    plt.ylim(ymin, 2.5)
    plt.xlim(0, hrs.max())
    plt.yticks(yticks, ylabels)
    if title is not None:
        plt.title(title)
    plt.show()
# --------------------------------------------------------------------------------
[docs]
def hypno_density( probs , e = None , xsize = 20 , ysize = 2 , title = None ):
    """Plot a stacked-probability hypno-density chart from POPS/SOAP output.

    Displays per-epoch posterior stage probabilities as a stacked area plot,
    giving an at-a-glance picture of staging confidence across the night.

    Parameters
    ----------
    probs : pandas.DataFrame
        DataFrame containing columns ``PP_N1``, ``PP_N2``, ``PP_N3``,
        ``PP_R``, and ``PP_W`` (as returned by the POPS command). Nothing
        is drawn when it is ``None`` or empty.
    e : ignored, optional
        Reserved for future use; currently unused.
    xsize : float, optional
        Figure width in inches. Default ``20``.
    ysize : float, optional
        Figure height in inches. Default ``2``.
    title : str, optional
        Optional plot title.

    Returns
    -------
    None
        The chart is rendered inline via Matplotlib.
    """
    # no data?
    if probs is None or len(probs) == 0:
        return
    import matplotlib.pyplot as plt
    stages = ['N1', 'N2', 'N3', 'R', 'W']
    y = probs[['PP_' + s for s in stages]].to_numpy(dtype=float)
    ne = len(y)
    x = np.arange(1, ne + 1, 1)
    fig, ax = plt.subplots()
    # honour the caller's figure size (these were previously overwritten)
    fig.set_figheight(ysize)
    fig.set_figwidth(xsize)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Prob(stage)')
    # one stacked band per stage, coloured with the canonical stage palette
    ax.stackplot(x, y.T, colors=stgcol(stages))
    # ticks at both ends of each axis (np.arange(0, 1) produced only [0])
    ax.set(xlim=(1, ne), xticks=[1, ne],
           ylim=(0, 1), yticks=[0, 1])
    if title is not None:
        ax.set_title(title)
    plt.show()
# --------------------------------------------------------------------------------
[docs]
def psd(df , ch, var = 'PSD' , minf = None, maxf = None, minp = None, maxp = None ,
        xlines = None , ylines = None, dB = False ):
    """Plot power spectral density from a Luna ``PSD`` or ``MTM`` ``CH_F`` table.

    Parameters
    ----------
    df : pandas.DataFrame
        Result table with at least columns ``'CH'``, ``'F'``, and the
        power variable named by *var*. Typically the ``'PSD: CH_F'`` or
        ``'MTM: CH_F'`` table returned by :meth:`~lunapi.instance.inst.proc`.
    ch : str or list-like of str
        Channel label(s) to plot (a single label, or a list / tuple /
        numpy array / pandas Series of labels). Nothing is drawn if ``None``.
    var : str, optional
        Column name containing the power values (``'PSD'`` or ``'MTM'``).
        Default ``'PSD'``.
    minf : float, optional
        Minimum frequency (Hz) for the x-axis. Default: data minimum.
    maxf : float, optional
        Maximum frequency (Hz) for the x-axis. Default: data maximum.
    minp : float, optional
        Lower clipping bound for power values. Default: data minimum.
    maxp : float, optional
        Upper clipping bound for power values. Default: data maximum.
    xlines : float or list-like of float, optional
        Vertical reference lines at these frequencies.
    ylines : float or list-like of float, optional
        Horizontal reference lines at these power values.
    dB : bool, optional
        If true, convert power values to dB (10*log10) before plotting and
        label the y-axis accordingly. Default ``False``.

    Returns
    -------
    None
        The plot is rendered inline via Matplotlib.
    """
    def _as_list(v):
        # keep list-likes as ordered lists; wrap scalars (including str)
        if isinstance(v, (list, tuple, np.ndarray, pd.Series, pd.Index)):
            return list(v)
        return [v]

    if ch is None:
        return
    ch = _as_list(ch)
    # `is not None` (not `!= None`) so numpy arrays don't hit an
    # ambiguous-truth-value error
    if xlines is not None:
        xlines = _as_list(xlines)
    if ylines is not None:
        ylines = _as_list(ylines)

    df = df[df['CH'].isin(ch)]
    if len(df) == 0:
        return

    import matplotlib.pyplot as plt

    f = df['F'].to_numpy(dtype=float)
    p = df[var].to_numpy(dtype=float)
    if dB:
        p = 10 * np.log10(p)
    cx = df['CH'].to_numpy(dtype=str)

    # defaults are taken over all frequencies, before the frequency filter
    if minp is None: minp = np.nanmin(p)
    if maxp is None: maxp = np.nanmax(p)
    if minf is None: minf = np.nanmin(f)
    if maxf is None: maxf = np.nanmax(f)

    keep = (f >= minf) & (f <= maxf)
    f, p, cx = f[keep], p[keep], cx[keep]

    # clip (rather than drop) out-of-range power so each curve stays continuous
    p[p > maxp] = maxp
    p[p < minp] = minp

    for _ch in ch:
        sel = cx == str(_ch)
        plt.plot(f[sel], p[sel], label=_ch)
    plt.legend()
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (dB)' if dB else 'Power')
    if xlines is not None:
        for _x in xlines:
            plt.axvline(_x, linewidth=1, color='gray')
    if ylines is not None:
        for _y in ylines:
            plt.axhline(_y, linewidth=1, color='gray')
    plt.show()
# --------------------------------------------------------------------------------
[docs]
def spec(df , ch = None , var = 'PSD' , mine = None , maxe = None , minf = None, maxf = None, w = 0.025 ):
    """Plot an epoch-by-frequency spectrogram from a Luna ``CH_E_F`` result table.

    Parameters
    ----------
    df : pandas.DataFrame
        Result table with columns ``'E'`` (epoch), ``'F'`` (frequency),
        ``'CH'`` (channel), and the power variable named by *var*.
        Typically the ``'PSD: CH_E_F'`` or ``'MTM: CH_E_F'`` table.
    ch : str, optional
        Channel to plot. If ``None``, all channels in *df* are included.
    var : str, optional
        Column name for power values. Default ``'PSD'``.
    mine : int, optional
        First epoch to display. Default: first epoch in the data.
    maxe : int, optional
        Last epoch to display. Default: last epoch in the data.
    minf : float, optional
        Minimum frequency (Hz). Default: data minimum.
    maxf : float, optional
        Maximum frequency (Hz). Default: data maximum.
    w : float, optional
        Winsorisation proportion applied to power values before colour
        mapping. Default ``0.025``.

    Returns
    -------
    None
        The spectrogram is rendered inline via Matplotlib. Nothing is drawn
        when no rows survive the channel / epoch / frequency selection.
    """
    if ch is not None:
        df = df.loc[df['CH'] == ch]
    if len(df) == 0:
        return
    x = df['E'].to_numpy(dtype=int)
    y = df['F'].to_numpy(dtype=float)
    z = df[var].to_numpy(dtype=float)
    if mine is None: mine = x.min()
    if maxe is None: maxe = x.max()
    if minf is None: minf = y.min()
    if maxf is None: maxf = y.max()
    keep = (x >= mine) & (x <= maxe) & (y >= minf) & (y <= maxf)
    # an empty selection would make spec0() fail on max() of an empty array
    if not keep.any():
        return
    x, y, z = x[keep], y[keep], z[keep]
    from scipy.stats.mstats import winsorize
    # trim extreme power values so they don't dominate the colour scale
    z = winsorize(z, limits=[w, w])
    spec0(x, y, z, mine, maxe, minf, maxf)
# --------------------------------------------------------------------------------
[docs]
def spec0( x , y , z , mine , maxe , minf, maxf ):
    """Render a 2-D spectrogram heatmap from raw epoch/frequency/value vectors.

    Low-level helper called by :func:`spec`. Bins *z* values into an
    epoch x frequency grid (one column per integer epoch in the data range,
    one row per distinct frequency), averages the values falling in each
    cell, and displays the result as a ``pcolormesh`` plot.

    Parameters
    ----------
    x : array-like of int
        Epoch index for each observation.
    y : array-like of float
        Frequency (Hz) for each observation.
    z : array-like of float
        Power value for each observation.
    mine : int
        Minimum epoch index for the x-axis.
    maxe : int
        Maximum epoch index for the x-axis.
    minf : float
        Minimum frequency (Hz) for the y-axis.
    maxf : float
        Maximum frequency (Hz) for the y-axis.

    Returns
    -------
    None
        The heatmap is rendered inline via Matplotlib. Nothing is drawn for
        empty input.
    """
    # empty input: max()/min() below would raise ValueError
    if len(x) == 0:
        return
    import matplotlib.pyplot as plt
    xn = max(x) - min(x) + 1
    yn = np.unique(y).size
    # weighted histogram = per-cell sum of z; plain histogram = per-cell count
    zi, yi, xi = np.histogram2d(y, x, bins=(yn, xn), weights=z, density=False)
    counts, _, _ = np.histogram2d(y, x, bins=(yn, xn))
    with np.errstate(divide='ignore', invalid='ignore'):
        zi = zi / counts
    # cells with no observations (0/0 -> NaN) are masked and left blank
    zi = np.ma.masked_invalid(zi)
    fig, ax = plt.subplots()
    fig.set_figheight(2)
    fig.set_figwidth(15)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Frequency (Hz)')
    ax.set(xlim=(mine, maxe), ylim=(minf, maxf))
    p1 = ax.pcolormesh(xi, yi, zi, cmap='turbo')
    fig.colorbar(p1)
    ax.margins(0.1)
    plt.show()
# --------------------------------------------------------------------------------
[docs]
def topo_heat(chs, z, ths = None , th=0.05 ,
              topo = None ,
              lmts= None , sz=70, colormap = "bwr", title = "",
              rimcolor="black", lab = "dB"):
    """Plot a channel-wise topographic heat map on a scalp electrode layout.

    Renders a scatter plot in electrode space where each dot is coloured
    by the scalar value *z* for that channel. Channels with an associated
    p-value below *th* are drawn with a thicker rim.

    Parameters
    ----------
    chs : array-like of str
        Channel labels corresponding to each value in *z*. Matching against
        the layout is case-insensitive; if a label is repeated, its first
        occurrence is used.
    z : array-like of float
        Scalar values to colour-map (one per channel in *chs*).
    ths : array-like of float, optional
        P-values (or thresholding values) for each channel. Channels
        with ``ths < th`` receive a highlighted rim. Default ``None``
        (no thresholding).
    th : float, optional
        Significance threshold applied to *ths*. Default ``0.05``.
    topo : pandas.DataFrame, optional
        Electrode coordinate table with columns ``['CH', 'X', 'Y']``.
        Defaults to the 64-channel layout from :func:`default_xy`. It is
        not modified.
    lmts : list of two float, optional
        ``[vmin, vmax]`` colour-map limits. Default: ``[min(z), max(z)]``.
    sz : float, optional
        Marker size (points^2) for each electrode dot. Default ``70``.
    colormap : str, optional
        Matplotlib colour map name. Default ``'bwr'``.
    title : str, optional
        Text label placed in the upper-left of the figure. Default ``''``.
    rimcolor : str, optional
        Edge colour for all electrode markers. Default ``'black'``.
    lab : str, optional
        Colour-bar label. Default ``'dB'``.

    Raises
    ------
    AssertionError
        If *chs*, *z* (and *ths*) differ in length, if *z* lies outside
        *lmts*, or if no channel in *chs* is present in the layout.

    Returns
    -------
    None
        The topoplot is rendered inline via Matplotlib.
    """
    z = np.asarray(z, dtype=float)
    # upper-case first so the matching check below is case-insensitive;
    # a plain list comprehension also accepts lists/arrays, not only Series
    chs = np.array([str(c).upper() for c in chs])
    assert len(z) == len(chs), "chs and z must have the same length"
    if ths is not None:
        ths = np.asarray(ths, dtype=float)
        assert len(ths) == len(chs), "chs and ths must have the same length"
    if topo is None:
        topo = default_xy()
    xlim = [-0.6, 0.6]
    ylim = [-0.6, 0.6]
    rng = [np.min(z), np.max(z)]
    if lmts is None:
        lmts = rng
    else:
        assert lmts[0] <= rng[0] <= lmts[1] and lmts[0] <= rng[1] <= lmts[1], "channel values are out of specified limits"

    # work on an explicit copy: never touch the caller's frame, and avoid
    # pandas' SettingWithCopy behaviour on a filtered slice
    topo = topo.copy()
    topo['CH'] = topo['CH'].astype(str).str.upper()
    topo = topo[topo['CH'].isin(chs)].copy()
    assert len(topo) > 0, "no matching channels"

    # label -> value lookups; setdefault keeps the first occurrence of a label
    val_of = {}
    for c, v in zip(chs, z):
        val_of.setdefault(c, v)
    topo['vals'] = topo['CH'].map(val_of)
    if ths is None:
        topo['th_vals'] = 999
        topo['rims'] = 0.5
    else:
        th_of = {}
        for c, t in zip(chs, ths):
            th_of.setdefault(c, t)
        topo['th_vals'] = topo['CH'].map(th_of)
        # thicker rim marks channels below the significance threshold
        topo['rims'] = np.where(topo['th_vals'] < th, 1.5, 0.5)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    sc = ax.scatter(topo['X'], topo['Y'], cmap=colormap,
                    c=topo['vals'], edgecolors=rimcolor,
                    linewidths=topo['rims'], s=sz, vmin=lmts[0], vmax=lmts[1])
    plt.text(-0.4, 0.5, s=title, fontsize=10, ha='center', va='center')
    # colour-bar limits and label are drawn as plain text next to the bar
    plt.text(0.15, -0.48, s=np.round(lmts[0], 2), fontsize=8, ha='center', va='center')
    plt.text(0.53, -0.48, s=np.round(lmts[1], 2), fontsize=8, ha='center', va='center')
    plt.text(0.35, -0.47, s=lab, fontsize=10, ha='center', va='center')
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.axis('off')
    cax = fig.add_axes([0.6, 0.15, 0.25, 0.02])  # [x, y, width, height]
    plt.colorbar(sc, cax=cax, orientation='horizontal')
    # acts on the current axes, i.e. the colour-bar axes just added
    plt.axis('off')
# example usage
#topo = default_xy()
#ch_names = topo.loc[:, "CH"] # vector of channel names
#ch_vals = np.random.uniform(0, 3, size=len(ch_names)) # vector of channel values
#ch_vals[0:3] = -18
#th_vals = np.random.uniform(0.06, 1, size=len(ch_names)) # vector of p-values (rim highlighted if < th)
#th_vals[ch_names == "O2"] = 0
#lmts=[-18, 4] # must contain all channel values; or None for [min, max]
#topo_heat(ch_names, ch_vals, ths = th_vals, th=0.05,
#          lmts=lmts, sz=70,
#          colormap = "bwr", title = "DENSITY",
#          rimcolor="black", lab = "n/min")
[docs]
def scope( p,
           chs = None,
           bsigs = None ,
           hsigs = None,
           anns = None ,
           stgs = None ,
           stgcols = None ,
           stgns = None ,
           sigcols = None,
           anncols = None,
           throttle1_sr = 100 ,
           throttle2_np = 5 * 30 * 100 ,
           summary_mins = 30 ,
           height = 600 ,
           annot_height = 0.15 ,
           header_height = 0.04 ,
           footer_height = 0.01
           ):
    """Create the interactive notebook scope viewer for one instance.

    Parameters
    ----------
    p : inst
        Target Luna instance.
    chs, bsigs, hsigs : list of str, optional
        Signal channel selections for traces / band-power / Hjorth summaries.
        ``bsigs`` and ``hsigs`` default to the channels in ``p.var('eeg')``.
        All three are restricted to channels present in the EDF.
    anns : list of str, optional
        Annotation classes to include (default: all annotations).
    stgs : list of str, optional
        Sleep-stage annotation classes fetched for the hypnogram strip.
        Default ``['N1', 'N2', 'N3', 'R', 'W', '?', 'L']``.
    stgcols : dict, optional
        Stage-to-colour map. It is applied to the annotation palette and
        merged over the hypnogram's built-in stage colours.
    stgns : dict, optional
        Stage-to-numeric (hypnogram y position) map. It is merged over the
        built-in ``{'N1': -1, 'N2': -2, 'N3': -3, 'R': 0, 'W': 1, '?': 2, 'L': 2}``.
    sigcols, anncols : dict, optional
        Optional explicit colour overrides for signals / annotations.
    throttle1_sr : int, optional
        Input sample-rate throttle (only applied when an ``int``).
    throttle2_np : int, optional
        Output point-count throttle (only applied when an ``int``).
    summary_mins : int or float, optional
        Threshold (minutes) for summary behaviour in the backend.
    height : int, optional
        Main scope plot height.
    annot_height, header_height, footer_height : float, optional
        Relative layout proportions.

    Returns
    -------
    ipywidgets.AppLayout or None
        Widget application, or ``None`` when no valid channels/annots exist.
    """
    import plotly.graph_objects as go
    import plotly.express as px
    from ipywidgets import widgets, AppLayout
    from itertools import cycle

    # epoch length (seconds): slider step, epoch counter and initial window
    scope_epoch_sec = 30

    # stage defaults. The annotation palette and the hypnogram have always used
    # different N1 colours, so keep both; user-supplied values override them.
    default_stgs = [ 'N1' , 'N2' , 'N3' , 'R' , 'W' , '?' , 'L' ]
    default_annot_stgcols = { 'N1':'blue' , 'N2':'blue', 'N3':'navy','R':'red','W':'green','?':'gray','L':'yellow' }
    default_hyp_stgcols = { 'N1':'rgba(32, 178, 218, 1)' , 'N2':'blue', 'N3':'navy','R':'red','W':'green','?':'gray','L':'yellow' }
    default_stgns = { 'N1':-1 , 'N2':-2, 'N3':-3,'R':0,'W':1,'?':2,'L':2 }
    if stgs is None:
        stgs = list( default_stgs )
    annot_stgcols = default_annot_stgcols if stgcols is None else stgcols
    hyp_stgcols = { **default_hyp_stgcols , **( stgcols or {} ) }
    stgns = { **default_stgns , **( stgns or {} ) }

    # internally, we use 'sigs' but 'chs' is a more lunapi-consistent label
    sigs = chs

    # all signals/annotations present
    all_sigs = p.edf.channels()
    all_annots = p.edf.annots()

    # units: channel -> PDIM, from p.headers()
    hdr = p.headers()
    units = dict( zip( hdr.CH , hdr.PDIM ) )

    # defaults
    if sigs is None: sigs = all_sigs
    if bsigs is None: bsigs = p.var( 'eeg' ).split(",")
    if hsigs is None: hsigs = p.var( 'eeg' ).split(",")
    if anns is None: anns = all_annots

    # ensure we do not have weird channels (keep EDF order, drop unknowns)
    sigs = [x for x in all_sigs if x in sigs]
    bsigs = [x for x in sigs if x in bsigs ]
    hsigs = [x for x in sigs if x in hsigs ]
    anns = [x for x in all_annots if x in anns ]

    # empty?
    if len( sigs ) == 0 and len( anns ) == 0:
        print( 'No valid channels or annotations to display')
        return None

    # initiate segment-server
    ss = segsrv( p )
    ss.calc_bands( bsigs )
    ss.calc_hjorths( hsigs )
    # exact type checks (not isinstance) so that e.g. booleans are not applied
    if type( throttle1_sr ) is int: ss.input_throttle( throttle1_sr )
    if type( throttle2_np ) is int: ss.throttle( throttle2_np )
    if type( summary_mins ) is int or type( summary_mins ) is float: ss.summary_threshold_mins( summary_mins )
    ss.populate( chs = sigs , anns = anns )

    # some key variables
    nsecs_clk = ss.num_seconds_clocktime_original()

    # color palette: signals first, then annotations, from one shared cycle
    pcyc = cycle(px.colors.qualitative.Bold)
    palette = { sig : next(pcyc) for sig in sigs }
    apalette = { ann : next(pcyc) for ann in anns }

    # update w/ any user-specified cols, from anncols = { 'ann':'col' }
    if sigcols is not None:
        palette.update( sigcols )
    apalette.update( annot_stgcols )
    if anncols is not None:
        apalette.update( anncols )

    # define widgets
    wlay1 = widgets.Layout( width='95%' )

    # channel selection box
    chlab = widgets.Label( value = 'Channels:' )
    chbox = widgets.SelectMultiple( options=sigs, value=sigs, rows=7, description='', disabled=False , layout = wlay1 )

    # band-power / Hjorth selector: initial options follow the checkbox's initial mode
    band_hjorth_sel = widgets.Checkbox( value = True , description = 'Hjorth' , disabled=False, indent=False )
    pow_opts = hsigs if band_hjorth_sel.value is True else bsigs
    if len(pow_opts) != 0: pow_sel = widgets.Dropdown( options = pow_opts, value=pow_opts[0],description='',disabled=False,layout = wlay1 )
    else: pow_sel = widgets.Dropdown( options = pow_opts, value=None,description="Band power:",disabled=False,layout = wlay1 )

    # annotations (display)
    anlab = widgets.Label( value = 'Annotations:' )
    anbox = widgets.SelectMultiple( options=anns , value=[], rows=3, description='', disabled=False , layout = wlay1 )

    # annotations (instance list/navigation)
    a1lab = widgets.Label( value = 'Instances:' )
    ansel = widgets.SelectMultiple( options=anns , value=[], rows=3, description='', disabled=False , layout = wlay1 )
    a1box = widgets.Select( options=[None] , value=None, rows=3, description='', disabled=False , layout = wlay1 )

    # time display labels
    tbox = widgets.Label( value = 'T: ' )
    tbox2 = widgets.Label( value = '' )
    tbox3 = widgets.Label( value = '' )

    # misc buttons
    reset_button = widgets.Button( description='Reset', disabled=False,button_style='',tooltip='',layout=widgets.Layout(width='98%') )
    keep_xscale = widgets.Checkbox( value = False , description = 'Fixed int.' , disabled=False, indent=False )
    show_ranges = widgets.Checkbox( value = True , description = 'Units' , disabled=False, indent=False )

    # navigation: main slider (top), positioned at the window midpoint (seconds)
    smid = widgets.IntSlider(min=scope_epoch_sec/2, max=nsecs_clk - scope_epoch_sec/2, value=scope_epoch_sec/2, step=scope_epoch_sec, description='', readout=False,layout=widgets.Layout(width='100%') )

    # left panel buttons: interval width
    swid_label = widgets.Label( value = 'Width' )
    swid_dec_button = widgets.Button( description='<', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))
    swid = widgets.Label( value = '30' )
    swid_inc_button = widgets.Button( description='>', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))

    # left panel buttons: left/right advances
    epoch_label = widgets.Label( value = 'Epoch' )
    epoch_dec_button = widgets.Button( description='<', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))
    epoch = widgets.Label( value = '1' )
    epoch_inc_button = widgets.Button( description='>', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))

    # left panel buttons: Y-spacing
    yspace_label = widgets.Label( value = 'Space' )
    yspace_dec_button = widgets.Button( description='<', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))
    yspace = widgets.Label( value = '1' )
    yspace_inc_button = widgets.Button( description='>', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))

    # left panel buttons: Y-scaling
    yscale_label = widgets.Label( value = 'Scale' )
    yscale_dec_button = widgets.Button( description='<', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))
    yscale = widgets.Label( value = '0' )
    yscale_inc_button = widgets.Button( description='>', disabled=False,button_style='',tooltip='', layout=widgets.Layout(width='30px'))

    # --------------------- signal plotter (g)
    # trace order: traces (xNS), gaps (x1), labels (xNS), annots (xNA), clock-ticks (x1)
    fig = [go.Scatter(x = None,
                      y = None,
                      mode = 'lines',
                      line=dict(color=palette[sig], width=1),
                      hoverinfo='none',
                      name = sig ) for sig in sigs
           ] + [ go.Scatter( x = None , y = None ,
                             mode = 'lines' ,
                             fill='toself' ,
                             fillcolor='#223344',
                             line=dict(color='#888888', width=1),
                             hoverinfo='none',
                             name='Gap' )
           ] + [ go.Scatter( x = None , y = None ,
                             mode='text' ,
                             textposition='middle right',
                             textfont=dict(
                                 size=11,
                                 color='white'),
                             hoverinfo='none' ,
                             showlegend=False ) for sig in sigs
           ] + [ go.Scatter( x = None ,
                             y = None ,
                             mode = 'lines',
                             fill='toself',
                             line=dict(color=apalette[ann], width=1),
                             hoverinfo='none',
                             name = ann ) for ann in anns
           ] + [ go.Scatter( x = None , y = None ,
                             mode = 'text' ,
                             textposition='bottom right',
                             textfont=dict(
                                 size=11,
                                 color='white'),
                             hoverinfo='none' ,
                             showlegend=False ) ]

    layout = go.Layout( margin=dict(l=8, r=8, t=0, b=0),
                        yaxis=dict(range=[0,1]),
                        modebar={'orientation': 'v','bgcolor': '#E9E9E9','color': 'white','activecolor': 'white' },
                        yaxis_visible=False,
                        yaxis_showticklabels=False,
                        xaxis_visible=False,
                        xaxis_showticklabels=False,
                        autosize=True,
                        height=height,
                        plot_bgcolor='rgb(02,15,50)' )

    g = go.FigureWidget(data=fig, layout= layout )
    g._config = g._config | {'displayModeBar': False}
    g.update_xaxes(showgrid=False)
    g.update_yaxes(showgrid=False)

    # -------------------- segment-plotter (sg)
    num_epochs = ss.num_epochs()
    tscale = ss.get_time_scale()
    tstarts = [ tscale[idx] for idx in range(0,len(tscale),2)]
    tstops = [ tscale[idx] for idx in range(1,len(tscale),2)]
    times = np.concatenate((tstarts, tstops), axis=1)

    # upper/lower boxes, then frame select, then actual segs
    sfig = [ go.Scatter(x=[0,0],y=[0.05,0.05],
                        mode='markers+lines',
                        marker=dict(color="navy",size=8))
           ] + [ go.Scatter(x=[0,0],y=[0.95,0.95],
                            mode='markers+lines',
                            marker=dict(color="navy",size=8))
           ] + [ go.Scatter(x=[0,0,0,0,0,None],y=[0,0,1,1,0,None],
                            mode='lines',
                            fill='toself',
                            fillcolor = 'rgba( 18, 65, 92, 0.75)' ,
                            line=dict(color="red",width=0.5))
           ] + [ go.Scatter(x=[x[1],x[1],x[3],x[3]],y=[0,1,1,0],  # was 0 1 3 2
                            fill="toself",
                            mode = 'lines',
                            hoverinfo = 'none',
                            line=dict(color='rgb(19,114,38)', width=1), ) for x in times ]

    slayout = go.Layout( margin=dict(l=8, r=8, t=2, b=4),
                         showlegend=False,
                         xaxis=dict(range=[0,1]),
                         yaxis=dict(range=[0,1]),
                         yaxis_visible=False,
                         yaxis_showticklabels=False,
                         xaxis_visible=False,
                         xaxis_showticklabels=False,
                         autosize=True,
                         height=15,
                         plot_bgcolor='rgb(255,255,255)' )

    sg = go.FigureWidget( data=sfig, layout=slayout )
    sg._config = sg._config | {'displayModeBar': False}

    # --------------------- hypnogram-level summary
    # clock-time stage info (in units no larger than 30 seconds)
    stg_evts = p.fetch_annots( stgs , 30 )
    if len( stg_evts ) != 0:
        # add a copy of each event positioned at its Stop time; sorting by
        # (Start, IDX) then yields start/stop point pairs, i.e. flat steps
        stg_evts2 = stg_evts.copy()
        stg_evts2[ 'Start' ] = stg_evts2[ 'Stop' ]
        stg_evts[ 'IDX' ] = range(len(stg_evts))
        stg_evts2[ 'IDX' ] = range(len(stg_evts))
        stg_evts = pd.concat( [stg_evts2, stg_evts] )
        stg_evts = stg_evts.sort_values(by=['Start', 'IDX'])
        htimes = stg_evts['Start'].to_numpy()
        classes = stg_evts['Class'].tolist()
        # classes missing from the maps are drawn like '?' (y=2, gray)
        ys = [ stgns.get(c, 2) for c in classes ]
        cols = [ hyp_stgcols.get(c, 'gray') for c in classes ]
    else:
        htimes = None
        ys = None
        cols = None

    hypfig = [ go.Scatter( x = htimes, y=ys, mode='lines', line=dict(color='gray')) ]
    hypfig.append( go.Scatter(x = htimes,
                              y = ys ,
                              mode = 'markers' ,
                              marker=dict( color = cols , size=2),
                              hoverinfo='none' ) )

    hyplayout = go.Layout( margin=dict(l=8, r=8, t=0, b=0),
                           showlegend=False,
                           xaxis=dict(range=[0,nsecs_clk]),
                           yaxis=dict(range=[-4,3]),
                           yaxis_visible=False,
                           yaxis_showticklabels=False,
                           xaxis_visible=False,
                           xaxis_showticklabels=False,
                           autosize=True,
                           height=35,
                           plot_bgcolor='rgb(255,255,255)' )

    hypg = go.FigureWidget( data = hypfig , layout = hyplayout )
    hypg._config = hypg._config | {'displayModeBar': False}

    # --------------------- band power/spectrogram (bg)
    bfig = go.Heatmap( z = None , type = 'heatmap', colorscale = 'turbo', showscale = False , hoverinfo = 'none' )

    blayout = go.Layout( margin=dict(l=8, r=8, t=0, b=0),
                         modebar={'orientation': 'h','bgcolor': '#E9E9E9','color': 'white','activecolor': 'white' },
                         showlegend=False,
                         yaxis_visible=False,
                         yaxis_showticklabels=False,
                         xaxis_visible=False,
                         xaxis_showticklabels=False,
                         autosize=True,
                         height=50,
                         plot_bgcolor='rgb(255,255,255)' )

    bg = go.FigureWidget( bfig , blayout )
    bg._config = bg._config | {'displayModeBar': False}

    # --------------------- build overall box (containerP)
    # ----- containers - left panel
    ctr_lab_container = widgets.VBox(children=[ swid_label , epoch_label, yspace_label , yscale_label ] ,
                                     layout = widgets.Layout( width='30%', align_items='center' , display='flex', flex_flow='column' ) )
    ctr_dec_container = widgets.VBox(children=[ swid_dec_button , epoch_dec_button, yspace_dec_button , yscale_dec_button ] ,
                                     layout = widgets.Layout( width='20%', align_items='center' , display='flex', flex_flow='column' ))
    ctr_val_container = widgets.VBox(children=[ swid , epoch , yspace , yscale ] ,
                                     layout = widgets.Layout( width='30%', align_items='center' , display='flex', flex_flow='column' ))
    ctr_inc_container = widgets.VBox(children=[ swid_inc_button , epoch_inc_button, yspace_inc_button , yscale_inc_button ] ,
                                     layout = widgets.Layout( width='20%', align_items='center' , display='flex', flex_flow='column' ))

    # left panel: group top set of widgets
    ctr_container = widgets.VBox( children=[ tbox, widgets.HBox(children=[ ctr_lab_container, ctr_dec_container, ctr_val_container, ctr_inc_container ] ) , reset_button ] ,
                                  layout = widgets.Layout( width='100%' ) )

    # left panel: lower buttons
    lower_buttons = widgets.HBox( children=[ keep_xscale , show_ranges ] ,
                                  layout = widgets.Layout( width='100%' ) )

    # left panel: construct all
    left_panel = widgets.VBox(children=[ ctr_container,
                                         chlab, chbox,
                                         widgets.HBox( children = [ band_hjorth_sel, pow_sel ] ),
                                         anlab, anbox, a1lab, ansel, a1box,
                                         lower_buttons ] ,
                              layout = widgets.Layout( width='95%' , margin='0 0 0 5px' , overflow_x = 'hidden' ) )

    # right panel: combine plots
    containerS = widgets.VBox(children=[ smid , hypg, sg, bg, g ] , layout = widgets.Layout( width='95%' , margin='0 5px 0 5px' , overflow_x = 'hidden' ) )

    # make the final app (just join left+right panels)
    container_app = AppLayout(header=None,
                              left_sidebar=left_panel,
                              center=containerS,
                              right_sidebar=None,
                              pane_widths=[1, 8, 0],
                              align_items = 'stretch' ,
                              footer=None , layout = widgets.Layout( border='3px none #708090' , margin='10px 5px 10px 5px' , overflow_x = 'hidden' ) )

    # --------------------- callback functions

    def apply_scaling():
        # push current channel/annotation counts and Y scale/spacing to segsrv
        ss.set_scaling( len(chbox.value) , len( anbox.value) , 2**float(yscale.value) , float(yspace.value) , header_height, footer_height , annot_height )

    def redraw():
        # update hms message
        tbox.value = 'T: ' + ss.get_window_left_hms() + ' - ' + ss.get_window_right_hms()

        # get annots
        ss.compile_windowed_annots( anbox.value )
        x1 = ss.get_window_left()
        x2 = ss.get_window_right()

        # update pointers on segment plot (as fractions of the recording)
        s1 = x1 / nsecs_clk
        s2 = x2 / nsecs_clk
        sg.data[0].x = [ s1, s2 ]
        sg.data[1].x = [ s1, s2 ]
        sg.data[2].x = [ s1 , s2 , s2 , s1 , s1 , None ]

        # update main plot
        with g.batch_update():
            ns = len(sigs)
            na = len(anns)

            # axes
            g.update_xaxes(range = [x1,x2])

            # signals: traces 0 .. ns-1; idx is the row among *selected* channels
            selected = [ x in chbox.value for x in sigs ]
            idx = 0
            for i in range(ns):
                if selected[i]:
                    g.data[i].x = ss.get_timetrack( sigs[i] )
                    g.data[i].y = ss.get_scaled_signal( sigs[i] , idx )
                    g.data[i].visible = True
                    idx += 1
                else:
                    g.data[i].visible = False

            # gaps: trace ns
            gidx = ns
            gaps = list( ss.get_gaps() )
            if len(gaps) == 0:
                g.data[ gidx ].visible = False
            else:
                # one closed rectangle per gap (6-value format, None-separated)
                xgaps = [(a, b, b, a, a, None ) for a, b in gaps ]
                ygaps = [(0, 0, 1-header_height, 1-header_height, 0, None ) for a, b in gaps ]
                g.data[ gidx ].x = [x for sub in xgaps for x in sub]
                g.data[ gidx ].y = [y for sub in ygaps for y in sub]
                g.data[ gidx ].visible = True

            # channel range labels: traces ns+1 .. 2ns
            if show_ranges.value is True:
                idx = 0
                xl = x1 + (x2-x1 ) * 0.01
                for i in range(ns):
                    if selected[i]:
                        ylim = ss.get_window_phys_range( sigs[i] )
                        ylab = sigs[i] + ' ' + str(round(ylim[0],3)) + ':' + str(round(ylim[1],3)) + ' (' + str(units.get(sigs[i], '')) +')'
                        g.data[i+ns+1].x = [ xl ]
                        g.data[i+ns+1].y = [ ss.get_ylabel( idx ) * (1 - header_height ) ]
                        g.data[i+ns+1].text = [ ylab ]
                        g.data[i+ns+1].visible = True
                        idx += 1
                    else:
                        g.data[i+ns+1].visible = False
            else:
                # 'Units' unchecked: hide labels drawn while it was checked
                for i in range(ns):
                    g.data[i+ns+1].visible = False

            # annots: traces 2ns+1 .. 2ns+na
            ns2 = 2 * ns + 1
            selected = [ x in anbox.value for x in anns ]
            for i in range(na):
                if selected[i]:
                    g.data[i+ns2].x = ss.get_annots_xaxes( anns[i] )
                    g.data[i+ns2].y = ss.get_annots_yaxes( anns[i] )
                    g.data[i+ns2].visible = True
                else:
                    g.data[i+ns2].visible = False

            # clock-ticks: last trace, centred vertically in the header band
            gidx = 2 * ns + na + 1
            tks = ss.get_clock_ticks(6)
            tx = list( tks.keys() )
            tv = list( tks.values() )
            if len( tx ) == 0:
                g.data[ gidx ].visible = False
            else:
                g.data[ gidx ].x = tx
                g.data[ gidx ].y = [ 1 - header_height + ( header_height ) * 0.5 for x in tx ]
                g.data[ gidx ].text = tv
                g.data[ gidx ].visible = True

    def rescale(change):
        apply_scaling()
        redraw()

    def update_bandpower(change):
        if pow_sel.value is None: return
        if len( pow_sel.value ) == 0: return
        if band_hjorth_sel.value is True:
            S = ss.get_hjorths( pow_sel.value )
        else:
            S = ss.get_bands( pow_sel.value )
        S = np.asarray( np.transpose( S ) , dtype=object )
        # NaN cells are replaced with None (treated as missing by plotly)
        S[np.isnan(S.astype(np.float64))] = None
        bg.update_traces({'z': S } , selector = {'type':'heatmap'} )

    def pop_a1(change):
        a1box.options = ss.get_all_annots( ansel.value )

    def a1_win(change):
        # nothing selected (e.g. while the instance list is being repopulated)
        if not a1box.value: return
        # format <annot> | t1-t2 (seconds)
        # allow for pipe in <annot> name
        nwin = a1box.value.split( '| ')[-1]
        nwin = nwin.split('-')
        nwin = [ float(x) for x in nwin ]
        # center on mid of annot
        mid = nwin[0] + ( nwin[1] - nwin[0] ) / 2
        # width: either based on annot, or keep as is
        if keep_xscale.value is False:
            # change the width silently; the smid update below redraws once
            swid.unobserve(set_window_from_sliders, names="value")
            swid.value = str( round( nwin[1] - nwin[0] , 2 ) )
            swid.observe(set_window_from_sliders, names="value")
        # update smid, and trigger redraw via set_window_from_sliders()
        smid.value = mid

    def set_window_from_sliders(change):
        w = float( swid.value )
        p1 = smid.value - 0.5 * w
        if p1 < 0: p1 = 0
        p2 = p1 + w
        if p2 >= ss.num_seconds_clocktime():
            p2 = ss.num_seconds_clocktime() - 1
        ss.window( p1 , p2 )
        epoch.value = str(1+int(smid.value/scope_epoch_sec))
        redraw()

    def fn_reset(b):
        swid.value = str( 30 )
        yspace.value = str( 1 )
        yscale.value = str( 0 )

    def fn_dec_epoch(b):
        if ( smid.value - scope_epoch_sec ) >= smid.min :
            smid.value = smid.value - scope_epoch_sec

    def fn_inc_epoch(b):
        if ( smid.value + scope_epoch_sec ) <= smid.max :
            smid.value = smid.value + scope_epoch_sec

    def fn_dec_swid(b):
        swid_var = float( swid.value )
        if swid_var > 3.5: swid_var = swid_var / 2
        if swid_var > 100: swid.value = str( int( swid_var ))
        else: swid.value = str( swid_var )

    def fn_inc_swid(b):
        swid_var = float( swid.value )
        if swid_var < 40000: swid_var = swid_var * 2
        if swid_var > 100: swid.value = str( int( swid_var ) )
        else: swid.value = str( swid_var )

    def fn_yspace_dec(b):
        yspace_var = float( yspace.value )
        if yspace_var > 0.05: yspace_var = yspace_var - 0.1
        yspace.value = str( round( yspace_var , 1 ) )

    def fn_yspace_inc(b):
        yspace_var = float( yspace.value )
        if yspace_var < 0.95: yspace_var = yspace_var + 0.1
        yspace.value = str( round( yspace_var , 1 ) )

    def fn_yscale_dec(b):
        yscale_var = float( yscale.value )
        if yscale_var > -2: yscale_var = yscale_var - 0.2
        yscale.value = str( round( yscale_var , 1 ) )

    def fn_yscale_inc(b):
        yscale_var = float( yscale.value )
        if yscale_var < 2: yscale_var = yscale_var + 0.2
        yscale.value = str( round( yscale_var , 1 ) )

    def fn_hjorth_band(change):
        pow_sel.options = hsigs if band_hjorth_sel.value is True else bsigs
        # the dropdown value may not change (e.g. bsigs == hsigs), in which case
        # pow_sel's observer never fires; refresh explicitly for the new mode
        update_bandpower(None)

    # --------------------- hook up widgets

    # observers
    smid.observe(set_window_from_sliders, names="value")
    swid.observe(set_window_from_sliders, names="value")
    show_ranges.observe(set_window_from_sliders, names="value")
    band_hjorth_sel.observe( fn_hjorth_band , names="value")
    swid_dec_button.on_click(fn_dec_swid)
    swid_inc_button.on_click(fn_inc_swid)
    epoch_dec_button.on_click(fn_dec_epoch)
    epoch_inc_button.on_click(fn_inc_epoch)
    reset_button.on_click(fn_reset)

    # summaries
    pow_sel.observe(update_bandpower,names="value")

    # rescale plots
    yscale_dec_button.on_click( fn_yscale_dec )
    yscale_inc_button.on_click( fn_yscale_inc )
    yspace_dec_button.on_click( fn_yspace_dec )
    yspace_inc_button.on_click( fn_yspace_inc )
    yscale.observe( rescale , names="value")
    yspace.observe( rescale , names="value")

    # channel selection
    chbox.observe( rescale ,names="value")

    # annots
    anbox.observe( rescale , names="value")
    ansel.observe( pop_a1 , names="value")
    a1box.observe( a1_win , names="value")

    # --------------------- display
    update_bandpower(None)
    apply_scaling()
    ss.window( 0 , scope_epoch_sec )
    epoch.value = str(1)
    redraw()

    return container_app
# public API, grouped by role: label/montage helpers, static Matplotlib
# plots, and the interactive ipywidgets/Plotly viewer
__all__ = [
    "default_xy", "stgcol", "stgn",
    "hypno", "hypno_density", "psd", "spec", "spec0", "topo_heat",
    "scope",
]