NeuroMOSAICS: A collection of neurostimulation datasets — Multi-scale Open-Source Across Interfaces, Conditions & Species.
from import_data_osf import get_data
import numpy as np
%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.widgets import Button
from matplotlib.patches import FancyArrow
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle
def generate_heatmap_data(iMuscle):
    """Return a grid holding every electrode's mean response for one muscle.

    The grid has the shape of the global electrode layout ``maps``. Electrode
    ``k`` is written at row ``ch2xy[k][0]``, column ``ch2xy[k][1]`` with the
    value ``sorted_respMean[k][iMuscle]``; cells without an electrode stay 0.

    Parameters
    ----------
    iMuscle : int
        Column of ``sorted_respMean`` (muscle index) to place on the grid.
    """
    grid = np.zeros(maps.shape)
    for channel, coords in enumerate(ch2xy):
        row, col = int(coords[0]), int(coords[1])
        grid[row][col] = sorted_respMean[channel][iMuscle]
    return grid
def update_detailed_figure(iArray, iMuscle, detailed_figure, detailed_axes, d):
    """Redraw the detail figure for one electrode and one muscle.

    Panels, in the order of the flattened ``detailed_axes``:

    0. Raw EMG stack (trials x time) from ``sorted_evoked``; trials whose
       ``sorted_isvalid`` entry is 0 are overplotted in red.
    1. Filtered EMG stack from ``sorted_filtered`` on the same colour scale,
       with a colour bar kept in the global ``current_cbar``.
    2. Mean filtered EMG trace with the peak-search window ``resp_region``.
    3. Histogram of the per-trial peak amplitudes in ``sorted_resp``.

    Signals are multiplied by 1000 for display (the axis labels say mV).
    Everything else is read from module globals (``stim_channel``, ``time``,
    ``emgs``, ``fs``, ``where_zero``, font sizes, ...).

    Parameters
    ----------
    iArray : int
        0-based electrode index; its trials are those with
        ``stim_channel == iArray + 1``.
    iMuscle : int
        Index of the EMG channel (muscle) to show.
    detailed_figure : matplotlib.figure.Figure
        Figure that owns ``detailed_axes``; redrawn at the end.
    detailed_axes : array-like of Axes
        At least four axes, used as panels 0-3 above.
    d : dict
        Dataset dictionary. An entry is published as a module global only if
        no global of that name exists yet. Values the script derived afterwards
        (e.g. the rotated ``maps``) are therefore not reverted to the raw
        dataset versions.
    """
    global current_cbar
    for key, value in d.items():
        globals().setdefault(key, value)

    # Wipe all panels and the colour-bar axes added by the previous call.
    for ax in np.ravel(detailed_axes):
        ax.cla()
    if current_cbar is not None:
        detailed_figure.delaxes(current_cbar.ax)
        current_cbar = None

    n_repetitions = np.where(stim_channel == iArray + 1)[0].shape[0]
    # y-limit of the mean panel, shared across electrodes for this muscle.
    upLim = np.nanmax(np.mean(sorted_evoked[:, iMuscle, :n_repetitions], 1)) * 1000
    # Histogram y-limit: the largest trial count on any electrode. Channels
    # are numbered 1..n (cf. ``iArray + 1`` above), not 0..n-1.
    upCounts = max(np.where(stim_channel == ch)[0].shape[0]
                   for ch in range(1, ch2xy.shape[0] + 1))

    # Time ticks every ``step`` samples, aligned so that t = 0 is a tick.
    x_0 = np.where(time == 0)[0][0]
    step = max(time.shape[0] // 5, 1)  # avoid range(step=0) on short traces
    tick_positions = list(range(x_0 % step, time.shape[0], step))
    tick_labels = [np.round(time[i], 2) for i in tick_positions]
    trial_ticks = list(range(n_repetitions))

    raw = sorted_evoked[iArray, iMuscle, :n_repetitions, :] * 1000
    filtered_trials = sorted_filtered[iArray, iMuscle, :n_repetitions, :]
    # Both stacks share one colour scale: this muscle's raw EMG over all electrodes.
    minSorted = np.nanmin(sorted_evoked[:, iMuscle, :n_repetitions, :]) * 1000
    maxSorted = np.nanmax(sorted_evoked[:, iMuscle, :n_repetitions, :]) * 1000

    # Panel 0: raw stack, invalid trials overplotted in red.
    ax_raw = detailed_axes[0]
    ax_raw.imshow(raw, vmin=minSorted, vmax=maxSorted, aspect='auto', cmap='Blues')
    normalize = Normalize(vmin=minSorted, vmax=maxSorted)
    n_samples = raw.shape[1]
    # Only the displayed trials can be marked; a flagged index past
    # n_repetitions would raise IndexError on ``raw[row, :]``.
    for row in np.where(sorted_isvalid[iArray, iMuscle, :n_repetitions] == 0)[0]:
        ax_raw.scatter(range(n_samples), [row] * n_samples, c=raw[row, :],
                       cmap='Reds', norm=normalize, marker='s', s=10)
    ax_raw.set_xticks(tick_positions)
    ax_raw.set_xticklabels(tick_labels)
    ax_raw.set_yticks(trial_ticks)
    ax_raw.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_raw.set_ylabel("Number of trials", fontsize=fontsize_axes)
    ax_raw.set_title(f"Raw {emgs['emgs'][iMuscle]} EMG stack", fontsize=fontsize_title)
    ax_raw.text(
        0.95, 0.93, "▬ Outliers", fontsize=10, color='red',
        ha='right', transform=ax_raw.transAxes
    )

    # Panel 1: filtered stack plus a colour bar in a fixed figure slot.
    ax_filt = detailed_axes[1]
    im2 = ax_filt.imshow(filtered_trials * 1000, vmin=minSorted, vmax=maxSorted, aspect='auto', cmap='Blues')
    cax = detailed_figure.add_axes([0.91, 0.56, 0.02, 0.35])
    current_cbar = ax_filt.figure.colorbar(im2, cax=cax)
    current_cbar.ax.set_ylabel("MEP (mV)", rotation=-90, va="bottom", fontsize=fontsize_axes)
    ax_filt.set_xticks(tick_positions)
    ax_filt.set_xticklabels(tick_labels)
    ax_filt.set_yticks(trial_ticks)
    ax_filt.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_filt.set_title(f"Filtered {emgs['emgs'][iMuscle]} EMG stack", fontsize=fontsize_title)

    # Panel 2: mean filtered trace and the sample window used for the peak.
    ax_mean = detailed_axes[2]
    ax_mean.plot(time, np.mean(filtered_trials, axis=0) * 1000, label="Mean filtered EMG")
    ax_mean.axvline((resp_region[0] - where_zero) / fs * 1000, color="orange", linestyle="dotted",
                    label="Range to compute maximum peak")
    ax_mean.axvline((resp_region[1] - where_zero) / fs * 1000, color="orange", linestyle="dotted")
    ax_mean.set_ylim(0, upLim)
    ax_mean.set_xlabel("Time (ms)", fontsize=fontsize_axes)
    ax_mean.set_ylabel("MEP (mV)", fontsize=fontsize_axes)
    ax_mean.set_title(f"{emgs['emgs'][iMuscle]} ", fontsize=fontsize_title)
    ax_mean.legend(loc=1, fontsize=fontsize_legend)

    # Panel 3: distribution of per-trial peak amplitudes.
    ax_hist = detailed_axes[3]
    peaks = np.asarray(sorted_resp[iArray, iMuscle, :n_repetitions], dtype=float) * 1000
    # np.histogram cannot auto-range data containing NaN, so drop missing peaks.
    peaks = peaks[~np.isnan(peaks)]
    ax_hist.hist(peaks, bins=10, density=False, alpha=0.75)
    ax_hist.set_ylim(0, upCounts)
    ax_hist.set_xlim(0, np.nanmax(sorted_resp[:, iMuscle, :]) * 1000)
    ax_hist.set_xlabel('Peak amplitude', fontsize=fontsize_axes)
    ax_hist.set_ylabel('Counts', fontsize=fontsize_axes)
    ax_hist.set_title('Distribution of peak amplitude', fontsize=fontsize_title)
    ax_hist.grid(True)

    detailed_figure.suptitle(f'B. MEP details from electrode {iArray}', fontsize=fontsize_title, y=1)
    detailed_figure.subplots_adjust(left=0.1, right=0.9, top=0.92, bottom=0.05, wspace=0.5, hspace=0.4)
    detailed_figure.canvas.draw()
def on_button_click(event, iMuscle, fig):
    """Switch the main heatmap to muscle ``iMuscle``.

    Stores ``iMuscle`` in the global ``last_button_index`` (the heatmap click
    handler uses it to choose the muscle for the detail figure), regenerates
    the global ``heatmap_data`` and redraws ``ax_heatmap`` min-max normalised
    to [0, 1].

    Parameters
    ----------
    event : matplotlib event
        The button click event (unused).
    iMuscle : int
        Index of the selected muscle.
    fig : matplotlib.figure.Figure
        Figure containing ``ax_heatmap``; redrawn at the end.
    """
    global heatmap_data, last_button_index
    last_button_index = iMuscle
    heatmap_data = generate_heatmap_data(iMuscle)

    ax_heatmap.cla()
    lo, hi = np.nanmin(heatmap_data), np.nanmax(heatmap_data)
    span = hi - lo
    # A constant map would divide by zero and render as all-NaN. Show it as
    # zeros instead; NaN cells stay NaN.
    normalized = (heatmap_data - lo) / span if span > 0 else heatmap_data - lo
    ax_heatmap.imshow(normalized, cmap='Blues')
    ax_heatmap.set_title(emgs['emgs'][iMuscle], fontsize=fontsize_title)
    fig.canvas.draw()
def on_heatmap_click(event, d, detailed_figure, detailed_axes, ax_heatmap):
    """Handle a mouse click on the main heatmap.

    Maps the click to a grid cell and looks up the electrode whose ``ch2xy``
    entry (column 0 = row, column 1 = column) matches that cell. It then
    outlines the cell in red and, if a muscle is selected
    (``last_button_index``), redraws the detail figure for that electrode.
    Clicks outside ``ax_heatmap``, outside the grid, or on a cell without an
    electrode are ignored.

    Parameters
    ----------
    event : matplotlib.backend_bases.MouseEvent
        The click event.
    d : dict
        Dataset dictionary. An entry is published as a module global only if
        no global of that name exists yet, so derived values such as the
        rotated ``maps`` are not reverted.
    detailed_figure, detailed_axes :
        Passed through to ``update_detailed_figure``.
    ax_heatmap : matplotlib.axes.Axes
        Axes showing the heatmap image.
    """
    for key, value in d.items():
        globals().setdefault(key, value)
    if event.inaxes is not ax_heatmap or event.xdata is None or event.ydata is None:
        return

    # imshow centres cell (row, col) at data coordinates (x=col, y=row), so
    # rounding the click position gives the cell directly. No axis flip is
    # needed. floor(v + 0.5) avoids round-half-to-even at cell borders.
    x_idx = int(np.floor(event.xdata + 0.5))
    y_idx = int(np.floor(event.ydata + 0.5))
    if not (0 <= x_idx < heatmap_data.shape[1] and 0 <= y_idx < heatmap_data.shape[0]):
        return

    matches = np.intersect1d(np.where(ch2xy[:, 1] == x_idx)[0], np.where(ch2xy[:, 0] == y_idx)[0])
    if matches.size == 0:
        return  # empty grid cell: no electrode to show
    iArray = matches[0]

    # Replace the previous selection outline. Iterate over a copy because
    # removing an artist mutates ax.patches.
    for patch in list(ax_heatmap.patches):
        patch.remove()
    rect = patches.Rectangle((x_idx - 0.5, y_idx - 0.5), 1, 1, linewidth=2, edgecolor='red', facecolor='none')
    ax_heatmap.add_patch(rect)
    ax_heatmap.figure.canvas.draw_idle()

    if last_button_index is not None:
        update_detailed_figure(iArray, last_button_index, detailed_figure, detailed_axes, d)
# Load one recording (rat 1, C5, 500 uA) from its pickle file and publish
# every field returned by get_data as a module-level variable.
is_mat = False
with open("../data/neuromosaics/osfstorage/rat1/rat1_C5_500uA.pkl" , "rb") as pkl_file:
    data = pickle.load(pkl_file)  # NOTE: only unpickle files from a trusted source
d1 = get_data(is_mat=is_mat, is_macaque=False, mat=data, name='rat1_C5_500uA')
globals().update(d1)
# Define some plotting parameters
# Index of the first sample whose stimulation-profile magnitude exceeds 1e-50,
# i.e. the first non-zero sample; it is used as t = 0 below.
where_zero = np.where(abs(stimProfile) > 10**(-50))[0][0]
# Time axis for the last axis of evoked_emg, shifted so sample `where_zero` is 0
# and multiplied by 1000 (labelled "Time (ms)" in the plots; assumes fs is in
# Hz — TODO confirm).
time = np.array([i/fs for i in range(-int(where_zero), evoked_emg.shape[2] - int(where_zero))]) * 1000
fontsize_title = 15
fontsize_axes = 13
fontsize_legend = 10
# Untouched copy of the electrode layout; `maps` is later rebound to this copy
# and possibly to a rotated version of it, leaving the dataset's array intact.
maps1 = np.copy(maps)
# PLOT
maps = maps1
# Colour bar of the detail figure; update_detailed_figure removes and recreates it.
current_cbar = None
# Main figure: heatmap on the left (4/5 of the width), button panel on the right.
fig_number = len(plt.get_fignums()) + 1  # number of open figures + 1
fig, (ax_heatmap, ax_buttons) = plt.subplots(
    1, 2,
    figsize=(9, 6),
    gridspec_kw=dict(width_ratios=[4, 1]),
    num=f'Main Plot {fig_number}',
)
plt.subplots_adjust(left=0.1, right=0.9, wspace=0.4)
fig.suptitle('A. Rat motor response heatmap', fontsize=fontsize_title)

# Heatmap of the first muscle, built in the stored layout orientation.
heatmap_data = generate_heatmap_data(0)
if not is_mat:
    # Rotate heatmap and layout by quarter turns according to the array side.
    quarter_turns = {'Left': 1, 'Right': 3}.get(additional_info['Position_Line_0'].item())
    if quarter_turns is not None:
        heatmap_data = np.rot90(heatmap_data, quarter_turns)
        maps = np.rot90(maps1, quarter_turns)
    # Re-derive each channel's (row, col) from the possibly rotated layout.
    # NOTE(review): original indentation was lost; this loop is assumed to belong
    # to the `not is_mat` branch (it reads the `parameters` table) — confirm.
    for ch in range(int(parameters['nChan'].iloc[0])):
        rows, cols = np.where(maps == ch + 1)
        ch2xy[ch] = [rows[0], cols[0]]

# Display the map min-max normalised to [0, 1] with a colour bar on its right.
lo, hi = np.nanmin(heatmap_data), np.nanmax(heatmap_data)
im = ax_heatmap.imshow((heatmap_data - lo) / (hi - lo), cmap='Blues')
divider = make_axes_locatable(ax_heatmap)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = ax_heatmap.figure.colorbar(im, cax=cax)
cbar.ax.set_ylabel("Normalized MEP (mV)", rotation=-90, va="bottom", fontsize=fontsize_axes)
# Initially selected electrode: the grid cell at column x = 5, row y = 1 of the
# displayed heatmap. If no electrode sits there (e.g. a smaller array), fall
# back to channel 0 instead of failing with IndexError.
x = 5
y = 1
matches = np.intersect1d(np.where(ch2xy[:, 1] == x)[0], np.where(ch2xy[:, 0] == y)[0])
if matches.size:
    array10 = matches[0]
else:
    array10 = 0
    y, x = int(ch2xy[0][0]), int(ch2xy[0][1])
# Red outline marking the selected cell (cells are centred on integer coordinates).
rect = patches.Rectangle((x - 0.5, y - 0.5), 1, 1, linewidth=2, edgecolor='red', facecolor='none')
ax_heatmap.add_patch(rect)
if not is_mat:
    # Orientation compass in figure coordinates: an arrow pointing left
    # labelled 'Left' and an arrow pointing up labelled 'Rostral'.
    compass_x, compass_y = 0.1, 0.9
    fig.patches.extend([
        FancyArrow(compass_x, compass_y, -0.03, 0, color='black', transform=fig.transFigure,
                   width=0.005, head_width=0.02, head_length=0.02)
    ])
    # plt.text adds the text through the current Axes; transform=fig.transFigure
    # places it in figure coordinates regardless.
    plt.text(compass_x - 0.075, compass_y, 'Left', color='black', ha='center', va='center', transform=fig.transFigure)
    fig.patches.extend([
        FancyArrow(compass_x, compass_y, 0, 0.03, color='black', transform=fig.transFigure,
                   width=0.005, head_width=0.02, head_length=0.02)
    ])
    plt.text(compass_x, compass_y + 0.06, 'Rostral', color='black', ha='center', va='center', transform=fig.transFigure)
# Heatmap title: the 'Segment' entry of additional_info plus the first muscle name.
# NOTE(review): original indentation was lost; additional_info is otherwise only
# read under `not is_mat` — confirm it also exists for .mat inputs.
ax_heatmap.set_title('Array on ' + additional_info['Segment'].item() + '\n' + emgs['emgs'][0], fontsize=fontsize_title)
# Turn the right-hand axes into a blank panel behind the buttons: keep it
# visible but hide its frame, axis lines and ticks.
ax_buttons.set_visible(True)
ax_buttons.spines['top'].set_visible(False)
ax_buttons.spines['right'].set_visible(False)
ax_buttons.spines['left'].set_visible(False)
ax_buttons.spines['bottom'].set_visible(False)
ax_buttons.xaxis.set_visible(False)
ax_buttons.yaxis.set_visible(False)
ax_buttons.set_xticks([])
ax_buttons.set_yticks([])
# One Button per muscle, stacked downwards from y = 0.8 (figure coordinates)
# in steps of 0.06 and labelled with emgs['emgsabr']. `buttons` must stay
# referenced at module level, or the widgets stop responding.
button_positions = [0.8 - i * 0.06 for i in range(n_muscles)]
button_axes = [plt.axes([0.85, pos, 0.1, 0.05]) for pos in button_positions]
buttons = [Button(button_axes[i], emgs['emgsabr'][i]) for i in range(n_muscles)]
# Captions for the button column and for the red selection square.
square_symbol = '\u25A0'
plt.text(0.90, 0.88, "Muscle selection", fontsize=12, ha='center', transform=plt.gcf().transFigure)
plt.text(0.90, 0.2, f"{square_symbol} Selected electrode", fontsize=10, color='red', ha='center', transform=plt.gcf().transFigure)
# Usage hint in a light-blue rounded box near the bottom of the figure.
cmap = plt.get_cmap("Blues")
deep_blue = cmap(0.8) # Dark blue for text
light_blue = cmap(0.1) # Light blue for background
plt.text(0.7, 0.02,
         "Activate interactivity\n(top-right)\nThen press Play ▶ and\nclick muscles\nor the heatmap!",
         fontsize=12, color=deep_blue, ha='left', fontfamily='DejaVu Sans',
         bbox=dict(boxstyle="round,pad=0.2", edgecolor='none', facecolor=light_blue, alpha=0.7),
         transform=plt.gcf().transFigure )
# Wire each button to the heatmap update. The `index=i` default freezes the
# loop value; a plain closure over i would always see the last index.
for i, button in enumerate(buttons):
    button.on_clicked(lambda event, index=i: on_button_click(event, index, fig))
# Detail figure (2 x 2 panels) for the selected electrode. Its window number is
# the count of figures already open, which includes the main figure above.
detailed_fig_number = len(plt.get_fignums())
detailed_figure, detailed_axes = plt.subplots(2, 2, figsize=(9, 6), num=f'Detailed View {detailed_fig_number}')
plt.subplots_adjust(wspace=0.5)
detailed_axes = detailed_axes.flatten()
# The "Outliers" legend is drawn on panel 0 by update_detailed_figure. Any text
# added here would be wiped by the cla() in that call, so none is added.

# Clicking a heatmap cell shows that electrode in the detail figure.
fig.canvas.mpl_connect('button_press_event', lambda event: on_heatmap_click(event, d1, detailed_figure, detailed_axes, ax_heatmap))
# Start with the first muscle and the initially highlighted electrode.
last_button_index = 0
update_detailed_figure(array10, 0, detailed_figure, detailed_axes, d1)
Loading...