# Parkinson’s disease in the spinal cord: an exploratory study to establish T2*w, MTR and diffusion-weighted imaging metric values
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import plot
import statsmodels.formula.api as smf
import base64
import os
import plotly.io as pio
pio.renderers.default = "plotly_mimetype"

# Load the CSA data (presumably spinal cord cross-sectional area; one row per
# subject and spinal level, judging by the columns used below).
data = pd.read_csv('../../data/parkinsons-spinalcord-mri-metrics/data/CSA.csv')

# Spinal levels to plot, one subplot row each. Kept as strings because the
# plotting loop compares them against the string-converted "SpinalLevel" column.
spinal_levels = ['2', '3', '4', '5']

# Subplot grid: one row per spinal level (derived from `spinal_levels` so the
# two cannot drift apart), two columns (CSA vs Age, CSA vs UPDRS-III).
fig = make_subplots(
    rows=len(spinal_levels), cols=2,
    shared_yaxes=False,
    horizontal_spacing=0.32,
    vertical_spacing=0.08,
)

# Plot settings
# NOTE(review): `axis_title_size` is not referenced by the axis styling in this
# script (which hard-codes size 16) — confirm whether it is still needed.
axis_title_size = 18
marker_size = 8
# Loop through spinal levels, filling one subplot row per level:
#   column 1 -> CSA vs Age for both groups (HC and PD), each with an OLS fit;
#   column 2 -> CSA vs UPDRS-III total for the PD group only, with an OLS fit.
HC_COLOR = '#00517F'
PD_COLOR = '#B4464F'

# Set to True to print the statsmodels summary of every fitted regression.
PRINT_OLS_SUMMARY = False

# Normalise the level labels once so they compare equal to the string entries
# of `spinal_levels` (this does not depend on the loop variable).
data['SpinalLevel'] = data['SpinalLevel'].astype(str)


def _add_ols_line(fig, df, x_col, y_col, color, legendgroup, row, col):
    """Fit ``y_col ~ x_col`` by OLS on *df* and draw the fitted line.

    Rows missing either variable are excluded. When fewer than two complete
    rows remain, no line can be fitted: nothing is drawn and ``None`` is
    returned. Otherwise the fitted statsmodels results object is returned.
    """
    complete = df[[x_col, y_col]].dropna()
    if len(complete) < 2:
        return None
    model = smf.ols(f'{y_col} ~ {x_col}', data=complete).fit()
    # Predict on the sorted, NaN-free predictor so the line is drawn
    # monotonically left-to-right (sorted() over NaN has undefined order).
    x_pred = complete[x_col].sort_values().reset_index(drop=True)
    y_pred = model.predict(pd.DataFrame({x_col: x_pred}))
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=y_pred,
        mode='lines',
        line=dict(color=color),
        showlegend=False,
        legendgroup=legendgroup,
    ), row=row, col=col)
    return model


for idx, spinal_level in enumerate(spinal_levels):
    row = idx + 1

    # Data per row
    data_level = data[data['SpinalLevel'] == spinal_level]
    data_CTRL = data_level[data_level['CTRL_or_PD'].str.contains('CTRL', na=False)]
    data_PD = data_level[data_level['CTRL_or_PD'].str.contains('PD', na=False)]

    # --- Column 1: CSA vs Age ---
    for group_data, color, name in ((data_CTRL, HC_COLOR, 'HC'), (data_PD, PD_COLOR, 'PD')):
        fig.add_trace(go.Scatter(
            x=group_data['Age'],
            y=group_data['CSA'],
            mode='markers',
            marker=dict(color=color, size=marker_size, opacity=0.8),
            name=name,
            legendgroup=name,
            showlegend=(row == 1),  # one legend entry per group, not per row
        ), row=row, col=1)
        model = _add_ols_line(fig, group_data, 'Age', 'CSA', color, name, row, 1)
        if PRINT_OLS_SUMMARY and model is not None:
            print(f"OLS (CSA ~ Age) {name}, {spinal_level}:\n{model.summary()}")

    # --- Column 2: CSA vs UPDRSIII (PD only) ---
    # Colour and trace name are explicit here rather than reusing the
    # variables left over from the column-1 loop.
    fig.add_trace(go.Scatter(
        x=data_PD['UPDRSIII_total'],
        y=data_PD['CSA'],
        mode='markers',
        marker=dict(color=PD_COLOR, size=marker_size, opacity=0.6),
        name='PD',
        legendgroup='PD',
        showlegend=False,
    ), row=row, col=2)
    model = _add_ols_line(fig, data_PD, 'UPDRSIII_total', 'CSA', PD_COLOR, 'PD', row, 2)
    if PRINT_OLS_SUMMARY and model is not None:
        print(f"OLS (CSA ~ UPDRSIII_total) PD, {spinal_level}:\n{model.summary()}")

    # Axes styling: CSA on every y-axis; Age / UPDRSIII on the x-axes.
    axis_title_font = dict(size=16, family='Arial')
    for col, x_title in ((1, 'Age'), (2, 'UPDRSIII')):
        fig.update_yaxes(title_text='CSA', row=row, col=col,
                         title_font=axis_title_font, title_standoff=0)
        fig.update_xaxes(title_text=x_title, row=row, col=col,
                         title_font=axis_title_font, title_standoff=0)
# Embed the static PNG template behind the plots (layer="below"). The file is
# inlined as a base64 data URI, so the figure does not depend on the image path
# once it has been built.
template_path = "../templates_for_figures/suppl_figure3_template.png"
with open(template_path, "rb") as png_file:
    encoded_image = base64.b64encode(png_file.read()).decode()

template_image = dict(
    source=f"data:image/png;base64,{encoded_image}",
    # Position and size are in paper coordinates (0-1 spans the plot area), so
    # negative x / y > 1 place the image's top-left corner in the margins.
    xref="paper",
    yref="paper",
    x=-0.29,
    y=1.18,
    sizex=1.28,
    sizey=1.28,
    xanchor="left",
    yanchor="top",
    opacity=1,
    layer="below",
)
fig.add_layout_image(template_image)
# Final layout: 900x900 canvas, a transparent legend placed up and to the left
# of the plot area (paper coordinates), and enlarged top/left margins
# (presumably to make room for the legend and the template image).
legend_style = dict(
    x=-0.3,
    y=1.15,
    font=dict(size=16, family='Arial', color='black'),
    borderwidth=0,
    itemsizing='constant',
    bgcolor='rgba(255,255,255, 0)',  # alpha 0 -> fully transparent background
)
page_margins = dict(t=180, l=200, r=80, b=90)

fig.update_layout(
    height=900,
    width=900,
    legend=legend_style,
    margin=page_margins,
)
fig.show()
# Output: the interactive figure rendered by fig.show() (notebook "Loading..." placeholder).