Spaces:
Runtime error
Runtime error
from matplotlib import pyplot as plt | |
import numpy as np | |
import streamlit as st | |
import pandas as pd | |
from utils import getSquareYVectorised, getCircle, getBatman, transform, plotGridLines, discriminant | |
# Range of the matrix-entry / stretch / shear sliders, and the increment used
# by every slider (including the rotation angle).
minv, maxv = -5.0, 5.0
step = 0.1

# NumPy arrays rendered via str() (e.g. eigenvectors in the results table)
# are printed with three decimals.
np.set_printoptions(precision=3)

# Fixed, square plot window shared by both axes.
xlim = ylim = (-10, 10)
# Page heading plus a one-paragraph summary of what the app demonstrates.
_page_title = "Visualizing Eigenvectors with 2x2 Linear Transformations"
_page_intro = "This app shows the effect of a 2x2 linear transformation on simple shapes to understand the role of eigenvectors and eigenvalues in quantifying the nature of a transformation."
st.title(_page_title)
st.write(_page_intro)
# Sidebar controls: choose the shape, build the 2x2 transformation matrix `t`
# for the selected transformation family, display it, and offer the
# "show original space" toggle.
with st.sidebar:
    data = st.selectbox('Select type of dataset', ['Square', 'Circle', 'Batman'])
    # Bind `black` unconditionally so later code never depends on the
    # Batman-only checkbox having been rendered.
    black = False
    if data == 'Batman':
        black = st.checkbox(label='Black')
    transform_type = st.selectbox(
        'Select type of transformation', ['Custom', 'Stretch', 'Shear', 'Rotate']
    )
    st.write("---")
    if transform_type == 'Custom':
        # Every entry of A is chosen independently.
        st.markdown("Select elements of transformation matrix $A$")
        a_00 = st.slider(label='$a_{00}$', min_value=minv, max_value=maxv, value=1.0, step=step)
        a_01 = st.slider(label='$a_{01}$', min_value=minv, max_value=maxv, value=0.0, step=step)
        a_10 = st.slider(label='$a_{10}$', min_value=minv, max_value=maxv, value=0.0, step=step)
        a_11 = st.slider(label='$a_{11}$', min_value=minv, max_value=maxv, value=1.0, step=step)
        t = np.array([[a_00, a_01], [a_10, a_11]], dtype=np.float64)
    elif transform_type == 'Stretch':
        # Diagonal matrix: separate x/y factors, or one shared factor.
        both = st.checkbox('Set equal')
        if not both:
            stretch_x = st.slider(label='Stretch in x-direction', min_value=minv, max_value=maxv, value=1.0, step=step)
            stretch_y = st.slider(label='Stretch in y-direction', min_value=minv, max_value=maxv, value=1.0, step=step)
            t = np.array([[stretch_x, 0], [0, stretch_y]], dtype=np.float64)
        else:
            stretch = st.slider(label='Scale', min_value=minv, max_value=maxv, value=1.0, step=step)
            t = np.array([[stretch, 0], [0, stretch]], dtype=np.float64)
    elif transform_type == 'Shear':
        # Unit diagonal with shear terms off the diagonal.
        left, right = st.columns(2)
        with left:
            both = st.checkbox('Set equal')
        if not both:
            shear_x = st.slider(label='Shear in x-direction', min_value=minv, max_value=maxv, value=0.0, step=step)
            shear_y = st.slider(label='Shear in y-direction', min_value=minv, max_value=maxv, value=0.0, step=step)
            t = np.array([[1, shear_x], [shear_y, 1]], dtype=np.float64)
        else:
            with right:
                sign = st.checkbox('Opposite sign')
            shear = st.slider(label='Shear in both directions', min_value=minv, max_value=maxv, value=0.0, step=step)
            # 'Opposite sign' negates only the upper-right entry.
            upper_right = -shear if sign else shear
            t = np.array([[1, upper_right], [shear, 1]], dtype=np.float64)
    else:
        # Anti-clockwise rotation matrix [[cos, -sin], [sin, cos]].
        st.markdown("Rotate by $\\theta$ in anti-clockwise\ndirection")
        min_theta = -180.0
        max_theta = 180.0
        # One decimal matches the 0.1 slider step; a bare "%f" would render
        # six decimal places in the slider label.
        theta = st.slider(label='$\\theta$', min_value=min_theta, max_value=max_theta, value=0.0, step=step, format="%.1f°")
        rtheta = np.pi * theta / 180.0
        t = np.array([[np.cos(rtheta), -np.sin(rtheta)],
                      [np.sin(rtheta), np.cos(rtheta)]], dtype=np.float64)
    st.write("---")
    st.write("The transformation matrix A is:")
    st.table(pd.DataFrame(t))
    st.write("---")
    showNormalSpace = st.checkbox(label='Show original space (without transform)', value=False)
# Sample the selected outline.
#   Square / Circle: 1000 points on x in [-1, 1] with one y per x. The plot
#     also draws -y, so y is presumably the upper half (TODO confirm in utils).
#   Batman: getBatman returns parallel sequences of curve segments X, Y
#     (s=2 is presumably a size/scale factor -- confirm in utils).
if data in ('Square', 'Circle'):
    x = np.linspace(-1, 1, 1000)
    y = getSquareYVectorised(x) if data == 'Square' else getCircle(x)
else:
    X, Y = getBatman(s=2)
# Apply the transformation matrix `t` to the sampled outline.
if data != 'Batman':
    # The sampled half and its mirror image (y -> -y) are transformed separately.
    x_dash_up, y_dash_up = transform(x, y, t)
    x_dash_down, y_dash_down = transform(x, -y, t)
else:
    # Transform each Batman segment on its own; `transform` yields an (x, y)
    # pair per segment. The comprehension names deliberately avoid `t`, which
    # is the transformation matrix.
    transformed_segments = [transform(seg_x, seg_y, t) for seg_x, seg_y in zip(X, Y)]
    X_dash = [pair[0] for pair in transformed_segments]
    Y_dash = [pair[1] for pair in transformed_segments]
# Eigen-decomposition of A: `evl` holds the eigenvalues, and the columns of
# `evec` hold the matching eigenvectors.
evl, evec = np.linalg.eig(t)
fig, ax = plt.subplots()
if showNormalSpace:
    # Faint overlay of the untransformed shape for reference.
    if data == 'Batman':
        for idx, (seg_x, seg_y) in enumerate(zip(X, Y)):
            if black:
                fmt = 'k-'
            elif idx < 3:
                fmt = 'g-'
            else:
                fmt = 'r-'
            ax.plot(seg_x, seg_y, fmt, alpha=0.5, linewidth=1)
    else:
        ax.plot(x, y, 'r', alpha=0.5)
        ax.plot(x, -y, 'g', alpha=0.5)
    # Unscaled eigenvectors, drawn only when they are real.
    if not np.iscomplex(evec).any():
        for col in (0, 1):
            ax.quiver(0, 0, evec[0, col], evec[1, col], scale=1, scale_units='xy',
                      angles='xy', facecolor='black', alpha=0.5)
    plotGridLines(xlim, ylim, np.array([[1, 0], [0, 1]]), '#9D9D9D', 'Normal Space', 0.4)
# Draw the transformed shape at full opacity.
if data != 'Batman':
    ax.plot(x_dash_up, y_dash_up, 'r')
    ax.plot(x_dash_down, y_dash_down, 'g')
else:
    for i, (x, y) in enumerate(zip(X_dash, Y_dash)):
        if black:
            ax.plot(x, y, 'k-', linewidth=1)
        elif i < 3:
            ax.plot(x, y, 'g', linewidth=1)
        else:
            ax.plot(x, y, 'r', linewidth=1)
# A maps each eigenvector v_k to lambda_k * v_k, so the arrows are the
# eigenvectors scaled by their eigenvalues. Drawn only for real eigenpairs.
if not (np.iscomplex(evl).any() or np.iscomplex(evec).any()):
    # Raw strings: '\ ' and '\l' are invalid escapes in ordinary literals
    # (SyntaxWarning on Python 3.12+); the resulting label text is unchanged.
    ax.quiver(0, 0, evec[0, 0] * evl[0], evec[1, 0] * evl[0], scale=1, scale_units='xy',
              angles='xy', facecolor='cyan', label=r'$eigen\ vector_{\lambda_0}$')
    ax.quiver(0, 0, evec[0, 1] * evl[1], evec[1, 1] * evl[1], scale=1, scale_units='xy',
              angles='xy', facecolor='blue', label=r'$eigen\ vector_{\lambda_1}$')
plotGridLines(xlim, ylim, t, '#403B3B', 'Transformed space', 0.6)
# Annotate det(A) and D in the right-hand margin (x=11 lies beyond xlim).
# `discriminant` comes from utils; presumably the discriminant of A's
# characteristic polynomial, where D < 0 means complex eigenvalues -- confirm.
disc = discriminant(t)
ax.text(11, 3, '|A|={:.2f}'.format(np.linalg.det(t)), fontdict={'fontsize': 11})
ax.text(11, 2, 'D = {:.2f}'.format(disc), fontdict={'fontsize': 11})
if disc < 0:
    ax.text(13, 1, 'Negative!', fontdict={'fontsize': 8})
# Fixed, equal-aspect window with all tick marks and labels hidden.
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.set_aspect('equal', adjustable='box')
ax.xaxis.set_tick_params(labelbottom=False)
ax.yaxis.set_tick_params(labelleft=False)
ax.set_xticks([])
ax.set_yticks([])
fig.legend(bbox_to_anchor=(1.05, 0.86), loc=1, borderaxespad=0., fontsize=8)
st.pyplot(fig)
# Streamlit reruns this script on every widget change; close the figure so
# pyplot does not accumulate one open figure per rerun.
plt.close(fig)
# Eigen-pair summary. Vectors are shown via str(), so they follow the NumPy
# print precision; "Transformed Eigenvectors" are eigenvalue * eigenvector.
df = pd.DataFrame({
    'Eigenvalues': evl,
    'Eigenvectors': [str(evec[:, 0]), str(evec[:, 1])],
    'Transformed Eigenvectors': [str(evec[:, 0] * evl[0]), str(evec[:, 1] * evl[1])],
})
st.table(df.style.format({'Eigenvalues': '{:.2f}'}))
if np.iscomplex(evl).any() or np.iscomplex(evec).any():
    # Implicit concatenation (not a backslash continuation inside the literal)
    # so the message has exactly one space between "not" and "displayed".
    st.write("Due to complex eigenvectors and eigenvalues, the transformed "
             "eigenvectors are not displayed...")
# Render the long-form explanation stored next to the app. The context manager
# closes the handle (the previous bare open() leaked it on every rerun), and
# an explicit encoding avoids depending on the host's locale.
with open("description.md", "r", encoding="utf-8") as description_file:
    description = description_file.read()
st.markdown(description)