import numpy as np
import matplotlib.pyplot as plt

# Rank-k SVD compression demo: load an image, reduce it to a single grayscale
# plane, keep only the k largest singular values, and display the result.

img = plt.imread('cat.jpg')
k = 50  # number of singular values/vectors kept in the reconstruction

# plt.imread returns integer pixels in [0, 255] for JPEG but floats in [0, 1]
# for PNG, so derive the full-scale value from the dtype instead of assuming 255.
full_scale = 255 if np.issubdtype(img.dtype, np.integer) else 1.0

# Collapse colour channels into one plane by averaging R, G, B.
# Grayscale files already load as 2-D arrays (np.mean over axis=2 would fail),
# and any alpha channel (4th plane) is excluded so it cannot skew intensities.
if img.ndim == 2:
    img_gray = img.astype(float)
else:
    img_gray = np.mean(img[..., :3], axis=2)

# Thin SVD: full_matrices=False returns only the min(m, n) singular vectors,
# which is all a rank-k reconstruction can use (the default would allocate a
# full m x m U). Note that np.linalg.svd returns V already transposed (Vh),
# so the rows of V1 are the right singular vectors.
U1, S1, V1 = np.linalg.svd(img_gray, full_matrices=False)

# Scaling the columns of U by S is equivalent to U @ diag(S) without building
# the k x k diagonal matrix. Slicing clamps k to the available rank.
img_output = (U1[:, :k] * S1[:k]) @ V1[:k, :]

# Without vmin/vmax, imshow rescales 2-D data to its own min/max, which stretches
# contrast and misrepresents the reconstruction. Pin the display range to the
# source's full scale and clip the small over/undershoot a truncated SVD produces.
plt.imshow(np.clip(img_output, 0, full_scale), cmap='gray', vmin=0, vmax=full_scale)
plt.title('SVD compression')
plt.show()
