Skip to article content

NeuroMOSAICS: A collection of neurostimulation datasets - Multi-scale Open-Source Across Interfaces Conditions & Species.

from import_data_osf import get_data
import numpy as np
%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize 
from matplotlib.widgets import Button
from matplotlib.patches import FancyArrow
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle
# Function to generate heatmap data
def generate_heatmap_data(iMuscle, maps_shape=None, channel_xy=None, resp_mean=None):
    """Lay out one muscle's per-channel responses on the electrode grid.

    Channel ``i`` is written to grid cell ``channel_xy[i]`` = (row, column) with
    the value ``resp_mean[i][iMuscle]``; grid cells without a channel stay 0.
    (The data is the recorded mean response, not random data.)

    Parameters
    ----------
    iMuscle : int
        Muscle index, i.e. the column of ``resp_mean`` to map.
    maps_shape : tuple of int, optional
        Shape of the output grid. Defaults to ``maps.shape`` (module global).
    channel_xy : array-like, optional
        Per-channel (row, column) positions; values are truncated with ``int``.
        Defaults to the module global ``ch2xy``.
    resp_mean : array-like, optional
        Response table indexed ``[channel][muscle]``. Defaults to the module
        global ``sorted_respMean``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``maps_shape``.
    """
    # Fall back to the module globals so existing callers keep working unchanged.
    if maps_shape is None:
        maps_shape = maps.shape
    if channel_xy is None:
        channel_xy = ch2xy
    if resp_mean is None:
        resp_mean = sorted_respMean

    data_matrix = np.zeros(maps_shape)
    for i in range(np.shape(channel_xy)[0]):
        row, col = int(channel_xy[i][0]), int(channel_xy[i][1])
        data_matrix[row][col] = resp_mean[i][iMuscle]
    return data_matrix

# Function to create or update the detailed figure
def update_detailed_figure(iArray, iMuscle, detailed_figure, detailed_axes, d):
    """Redraw the four-panel "MEP details" figure for one electrode and one muscle.

    Panels, in the order of ``detailed_axes``:

    0. raw EMG trial stack (``sorted_evoked``); trials whose ``sorted_isvalid``
       flag is 0 are redrawn on top in red;
    1. filtered EMG trial stack (``sorted_filtered``) on the same colour scale
       as panel 0, with a colour bar;
    2. trial mean of the filtered EMG with the ``resp_region`` bounds marked;
    3. histogram of the per-trial values in ``sorted_resp``.

    Amplitudes are multiplied by 1000 before display (the axis labels say mV).

    Parameters
    ----------
    iArray : int
        0-based electrode index into the ``sorted_*`` arrays; its trials are
        counted with ``stim_channel == iArray + 1``.
    iMuscle : int
        Muscle index (second axis of the ``sorted_*`` arrays).
    detailed_figure : matplotlib.figure.Figure
        Figure owning ``detailed_axes``.
    detailed_axes : sequence of matplotlib.axes.Axes
        The four panel axes, already flattened.
    d : dict
        Dataset dictionary. Every key is re-published as a module global before
        drawing; the ``sorted_*``, ``stim_channel``, ``ch2xy``, ``emgs`` and
        (presumably) ``resp_region`` / ``fs`` names read below come from it.

    Side effects: rebinds module globals from ``d`` and replaces the global
    ``current_cbar`` with the newly created colour bar.
    """
    global current_cbar

    globals().update(d)

    # Wipe the four panels and the colour bar left by the previous call.
    for ax in np.ravel(detailed_axes):
        ax.cla()
    if current_cbar is not None:
        detailed_figure.delaxes(current_cbar.ax)
        current_cbar = None  # never try to delete the same colour bar twice

    # Number of trials recorded for this electrode (stim_channel is 1-based).
    n_repetitions = np.count_nonzero(stim_channel == iArray + 1)
    # y-limit of panel 2: largest trial-averaged raw response of THIS muscle over all electrodes.
    upLim = np.nanmax(np.mean(sorted_evoked[:, iMuscle, :n_repetitions], 1)) * 1000
    # y-limit of panel 3: largest trial count of any channel; channels are numbered 1..N.
    upCounts = max(np.count_nonzero(stim_channel == ch) for ch in range(1, ch2xy.shape[0] + 1))

    # Colour scale shared by both stacks: raw-EMG range of this muscle over all electrodes.
    minSorted = np.nanmin(sorted_evoked[:, iMuscle, :n_repetitions, :]) * 1000
    maxSorted = np.nanmax(sorted_evoked[:, iMuscle, :n_repetitions, :]) * 1000
    raw_stack = sorted_evoked[iArray, iMuscle, :n_repetitions, :] * 1000
    filtered_stack = sorted_filtered[iArray, iMuscle, :n_repetitions, :] * 1000

    # x ticks every fifth of the window, phased so that one tick lands on t = 0.
    x_0 = np.where(time == 0)[0][0]
    step = time.shape[0] // 5
    tick_positions = list(range(x_0 % step, time.shape[0], step))
    tick_labels = [np.round(time[i], 2) for i in tick_positions]
    trial_ticks = list(range(n_repetitions))

    # Panel 0: raw stack with invalid trials overlaid in red.
    ax_raw = detailed_axes[0]
    ax_raw.imshow(raw_stack, vmin=minSorted, vmax=maxSorted, aspect='auto', cmap='Blues')
    normalize = Normalize(vmin=minSorted, vmax=maxSorted)
    n_samples = raw_stack.shape[1]
    # Only trials that are actually displayed can be marked.
    invalid_rows = np.where(sorted_isvalid[iArray, iMuscle][:n_repetitions] == 0)[0]
    for row in invalid_rows:
        ax_raw.scatter(range(n_samples), [row] * n_samples,
                       c=raw_stack[row, :], cmap='Reds', norm=normalize, marker='s', s=10)
    ax_raw.set_xticks(tick_positions)
    ax_raw.set_xticklabels(tick_labels)
    ax_raw.set_yticks(trial_ticks)
    ax_raw.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_raw.set_ylabel("Number of trials", fontsize=fontsize_axes)
    ax_raw.set_title(f"Raw {emgs['emgs'][iMuscle]} EMG stack", fontsize=fontsize_title)
    ax_raw.text(0.95, 0.93, "▬ Outliers", fontsize=10, color='red',
                ha='right', transform=ax_raw.transAxes)

    # Panel 1: filtered stack plus the colour bar, placed in a fixed slot (figure coordinates).
    ax_filtered = detailed_axes[1]
    im_filtered = ax_filtered.imshow(filtered_stack, vmin=minSorted, vmax=maxSorted, aspect='auto', cmap='Blues')
    cax = detailed_figure.add_axes([0.91, 0.56, 0.02, 0.35])
    current_cbar = detailed_figure.colorbar(im_filtered, cax=cax)
    current_cbar.ax.set_ylabel("MEP (mV)", rotation=-90, va="bottom", fontsize=fontsize_axes)
    ax_filtered.set_xticks(tick_positions)
    ax_filtered.set_xticklabels(tick_labels)
    ax_filtered.set_yticks(trial_ticks)
    ax_filtered.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_filtered.set_title(f"Filtered {emgs['emgs'][iMuscle]} EMG stack", fontsize=fontsize_title)

    # Panel 2: trial-averaged filtered EMG; dotted lines mark resp_region (sample indices
    # converted to the same time axis as `time`).
    ax_mean = detailed_axes[2]
    mean_trace = np.mean(sorted_filtered[iArray, iMuscle, :n_repetitions, :], axis=0) * 1000
    ax_mean.plot(time, mean_trace, label="Mean filtered EMG")
    ax_mean.axvline((resp_region[0] - where_zero) / fs * 1000, color="orange", linestyle="dotted",
                    label="Range to compute maximum peak")
    ax_mean.axvline((resp_region[1] - where_zero) / fs * 1000, color="orange", linestyle="dotted")
    ax_mean.set_ylim(0, upLim)
    ax_mean.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_mean.set_ylabel("MEP (mV)", fontsize=fontsize_axes)
    ax_mean.set_title(f"{emgs['emgs'][iMuscle]} ", fontsize=fontsize_title)
    ax_mean.legend(loc=1, fontsize=fontsize_legend)

    # Panel 3: distribution of the per-trial response, x range shared across electrodes.
    ax_hist = detailed_axes[3]
    ax_hist.hist(sorted_resp[iArray, iMuscle, :n_repetitions] * 1000, bins=10, density=False, alpha=0.75)
    ax_hist.set_ylim(0, upCounts)
    ax_hist.set_xlim(0, np.nanmax(sorted_resp[:, iMuscle, :]) * 1000)
    ax_hist.set_xlabel('Peak amplitude', fontsize=fontsize_axes)
    ax_hist.set_ylabel('Counts', fontsize=fontsize_axes)
    ax_hist.set_title('Distribution of peak amplitude', fontsize=fontsize_title)
    ax_hist.grid(True)

    detailed_figure.suptitle(f'B. MEP details from electrode {iArray}', fontsize=fontsize_title, y=1)
    detailed_figure.subplots_adjust(left=0.1, right=0.9, top=0.92, bottom=0.05, wspace=0.5, hspace=0.4)
    detailed_figure.canvas.draw()

# Function to update the heatmap based on button click
def on_button_click(event, iMuscle, fig):
    """Redraw the main heatmap for the muscle whose button was pressed.

    Stores ``iMuscle`` in the global ``last_button_index`` (read by the heatmap
    click handler), regenerates the global ``heatmap_data`` and redraws
    ``ax_heatmap`` min-max normalised to [0, 1]. Clearing the axes would also
    drop the red selected-electrode rectangle and the "Array on ..." title, so
    both are restored afterwards.

    Parameters
    ----------
    event : matplotlib event
        Button click event (unused).
    iMuscle : int
        Index of the selected muscle.
    fig : matplotlib.figure.Figure
        Main figure, redrawn at the end.
    """
    global heatmap_data, last_button_index
    last_button_index = iMuscle
    heatmap_data = generate_heatmap_data(iMuscle)

    # Remember the selection rectangle(s) before cla() removes every patch.
    selections = [(p.get_xy(), p.get_width(), p.get_height()) for p in ax_heatmap.patches]

    ax_heatmap.cla()
    lo, hi = np.nanmin(heatmap_data), np.nanmax(heatmap_data)
    ax_heatmap.imshow((heatmap_data - lo) / (hi - lo), cmap='Blues')
    for xy, width, height in selections:
        ax_heatmap.add_patch(patches.Rectangle(xy, width, height, linewidth=2,
                                               edgecolor='red', facecolor='none'))

    # Same title layout as the initial heatmap: array location (if known) + muscle name.
    title = emgs['emgs'][iMuscle]
    if not is_mat:
        title = 'Array on ' + additional_info['Segment'].item() + '\n' + title
    ax_heatmap.set_title(title, fontsize=fontsize_title)

    fig.canvas.draw()
    
# Function to handle cell click
def cell_index_from_click(xdata, ydata):
    """Return the ``(column, row)`` imshow cell containing data point ``(xdata, ydata)``.

    ``imshow`` centres array element ``[r, c]`` on data coordinates ``(c, r)``
    with a half-width of 0.5, so the cell index is each coordinate rounded
    half-up. No y-flip is needed: y data coordinates already count rows.
    """
    return int(np.floor(xdata + 0.5)), int(np.floor(ydata + 0.5))


def on_heatmap_click(event, d, detailed_figure, detailed_axes, ax_heatmap):
    """Select the electrode under a mouse click on the main heatmap.

    Moves the red selection rectangle to the clicked cell and, when a muscle
    is selected (global ``last_button_index``), redraws the detailed figure for
    the channel whose ``ch2xy`` position is that cell. Clicks outside
    ``ax_heatmap``, outside the grid, or on a cell with no channel are ignored.

    Parameters
    ----------
    event : matplotlib.backend_bases.MouseEvent
        The button-press event.
    d : dict
        Dataset dictionary; its keys are re-published as module globals.
    detailed_figure, detailed_axes
        Passed through to ``update_detailed_figure``.
    ax_heatmap : matplotlib.axes.Axes
        Axes showing the heatmap.
    """
    globals().update(d)
    if event.inaxes != ax_heatmap:
        return

    x_idx, y_idx = cell_index_from_click(event.xdata, event.ydata)
    n_rows, n_cols = heatmap_data.shape
    if not (0 <= x_idx < n_cols and 0 <= y_idx < n_rows):
        return

    # Channel whose (row, column) grid position is the clicked cell; empty cells have none.
    matches = np.flatnonzero((ch2xy[:, 0] == y_idx) & (ch2xy[:, 1] == x_idx))
    if matches.size == 0:
        return
    iArray = matches[0]

    # Move the selection rectangle; iterate over a copy because remove() mutates the list.
    for patch in list(ax_heatmap.patches):
        patch.remove()
    ax_heatmap.add_patch(patches.Rectangle((x_idx - 0.5, y_idx - 0.5), 1, 1, linewidth=2,
                                           edgecolor='red', facecolor='none'))
    ax_heatmap.figure.canvas.draw_idle()

    if last_button_index is not None:
        update_detailed_figure(iArray, last_button_index, detailed_figure, detailed_axes, d)
is_mat = False

# Load the pickled recording rat1_C5_500uA from the local OSF mirror.
# NOTE(review): pickle.load executes arbitrary code from the file; only use trusted data files.
with open("../data/neuromosaics/osfstorage/rat1/rat1_C5_500uA.pkl", "rb") as f:
    data = pickle.load(f)

d1 = get_data(
    is_mat=is_mat,
    is_macaque=False,
    mat=data,
    name='rat1_C5_500uA',
)

# Expose every field returned by get_data as a module-level variable.
for key, value in d1.items():
    globals()[key] = value

# Define some plotting parameters
# Stimulus onset: first index at which |stimProfile| exceeds a negligible threshold.
where_zero = np.where(abs(stimProfile) > 10**(-50))[0][0]
# Time axis (sample / fs, scaled by 1000 -> ms if fs is in Hz), zero at the onset sample.
time = 1000 * np.array(
    [sample / fs for sample in range(-int(where_zero), evoked_emg.shape[2] - int(where_zero))]
)
fontsize_title, fontsize_axes, fontsize_legend = 15, 13, 10
# Untouched copy of the electrode map; display code may rebind `maps` to a rotated version of it.
maps1 = np.copy(maps)

# PLOT
maps = maps1
current_cbar = None  # colour bar of the detailed figure; replaced on each redraw

# Create the main figure and axes
fig_number = len(plt.get_fignums()) + 1  # Total existing figures + 1
fig, (ax_heatmap, ax_buttons) = plt.subplots(1, 2, figsize=(9, 6), gridspec_kw={'width_ratios': [4, 1]}, num=f'Main Plot {fig_number}')
fig.subplots_adjust(left=0.1, right=0.9, wspace=0.4)
fig.suptitle('A. Rat motor response heatmap', fontsize=fontsize_title)

# Heatmap data generation and plotting
heatmap_data = generate_heatmap_data(0)

if not is_mat:
    # Rotate the grid according to the side the array sits on (presumably to match the
    # Left / Rostral compass drawn on the figure).
    position = additional_info['Position_Line_0'].item()
    if position == 'Left':
        heatmap_data = np.rot90(heatmap_data, 1)
        maps = np.rot90(maps1, 1)
    elif position == 'Right':
        heatmap_data = np.rot90(heatmap_data, 3)
        maps = np.rot90(maps1, 3)
    # The click callbacks re-publish d1 into the module globals; keep d1's map in the
    # displayed orientation so later generate_heatmap_data() calls use the rotated shape.
    d1['maps'] = maps

    # Re-derive each channel's (row, column) cell from the displayed map; ch2xy is updated in place.
    for ch in range(int(parameters['nChan'].iloc[0])):
        ch2xy[ch] = np.argwhere(maps == ch + 1)[0]

im = ax_heatmap.imshow((heatmap_data - np.nanmin(heatmap_data)) / (np.nanmax(heatmap_data) - np.nanmin(heatmap_data)), cmap='Blues')
divider = make_axes_locatable(ax_heatmap)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = ax_heatmap.figure.colorbar(im, cax=cax)
cbar.ax.set_ylabel("Normalized MEP (mV)", rotation=-90, va="bottom", fontsize=fontsize_axes)

# Electrode highlighted at start-up: grid cell in column x = 5, row y = 1.
x, y = 5, 1
# Index of the channel whose ch2xy position is (row y, column x); the lowest index if several match.
array10 = np.flatnonzero((ch2xy[:, 0] == y) & (ch2xy[:, 1] == x))[0]
rect = patches.Rectangle(
    (x - 0.5, y - 0.5), 1, 1,
    linewidth=2, edgecolor='red', facecolor='none',
)
ax_heatmap.add_patch(rect)

if not is_mat:
    # Orientation compass in the top-left corner of the main figure (figure coordinates):
    # an arrow to the left labelled 'Left' and an arrow upward labelled 'Rostral'.
    # Drawn on the figure itself rather than via plt.text, which would attach the
    # labels to whatever axes happens to be current (here the colour-bar axes).
    compass_x, compass_y = 0.1, 0.9
    fig.add_artist(FancyArrow(compass_x, compass_y, -0.03, 0, color='black', transform=fig.transFigure,
                              width=0.005, head_width=0.02, head_length=0.02))
    fig.text(compass_x - 0.075, compass_y, 'Left', color='black', ha='center', va='center',
             transform=fig.transFigure)
    fig.add_artist(FancyArrow(compass_x, compass_y, 0, 0.03, color='black', transform=fig.transFigure,
                              width=0.005, head_width=0.02, head_length=0.02))
    fig.text(compass_x, compass_y + 0.06, 'Rostral', color='black', ha='center', va='center',
             transform=fig.transFigure)
    ax_heatmap.set_title('Array on ' + additional_info['Segment'].item() + '\n' + emgs['emgs'][0], fontsize=fontsize_title)

# The right-hand axes only reserves space for the muscle buttons: hide its frame, axes and ticks.
ax_buttons.set_visible(True)
for side in ('top', 'right', 'left', 'bottom'):
    ax_buttons.spines[side].set_visible(False)
ax_buttons.xaxis.set_visible(False)
ax_buttons.yaxis.set_visible(False)
ax_buttons.set_xticks([])
ax_buttons.set_yticks([])

# One button per muscle, stacked downward from the top of the column (figure coordinates).
button_positions = [0.8 - i * 0.06 for i in range(n_muscles)]
button_axes = [fig.add_axes([0.85, pos, 0.1, 0.05]) for pos in button_positions]
buttons = [Button(button_axes[i], emgs['emgsabr'][i]) for i in range(n_muscles)]

# Figure-level captions; fig.text avoids attaching them to the implicitly current (last button) axes.
square_symbol = '\u25A0'
fig.text(0.90, 0.88, "Muscle selection", fontsize=12, ha='center', transform=fig.transFigure)
fig.text(0.90, 0.2, f"{square_symbol} Selected electrode", fontsize=10, color='red', ha='center',
         transform=fig.transFigure)
cmap = plt.get_cmap("Blues")
deep_blue = cmap(0.8)  # Dark blue for text
light_blue = cmap(0.1)  # Light blue for background
fig.text(0.7, 0.02,
         "Activate interactivity\n(top-right)\nThen press Play ▶ and\nclick muscles\nor the heatmap!",
         fontsize=12, color=deep_blue, ha='left', fontfamily='DejaVu Sans',
         bbox=dict(boxstyle="round,pad=0.2", edgecolor='none', facecolor=light_blue, alpha=0.7),
         transform=fig.transFigure)

# Bind the loop index as a default argument so each button reports its own muscle.
for i, button in enumerate(buttons):
    button.on_clicked(lambda event, index=i: on_button_click(event, index, fig))

# Create and show the detailed figure second.
# The main figure already exists, so this count normally equals fig_number and the
# two windows share the same number.
detailed_fig_number = len(plt.get_fignums())
detailed_figure, detailed_axes = plt.subplots(2, 2, figsize=(9, 6), num=f'Detailed View {detailed_fig_number}')
detailed_figure.subplots_adjust(wspace=0.5)
detailed_axes = detailed_axes.flatten()

# Clicks on the main figure select an electrode and redraw the detailed view.
fig.canvas.mpl_connect('button_press_event', lambda event: on_heatmap_click(event, d1, detailed_figure, detailed_axes, ax_heatmap))
# Start with the first muscle and the electrode highlighted at start-up
# (update_detailed_figure clears every panel and draws its own "Outliers" label).
last_button_index = 0
update_detailed_figure(array10, 0, detailed_figure, detailed_axes, d1)
Loading...
NeuroMOSAICS: A collection of neurostimulation datasets - Multi-scale Open-Source Across Interfaces Conditions & Species.
Figure Macaque