# Standard PyTorch imports
import torch
from torch import nn
from sklearn.datasets import make_circles
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split  # Split data into train and test sets
from src.formazione.utils.utilita import get_time
import torch
import matplotlib.pyplot as plt
import numpy as np

def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
    """Plots decision boundaries of model predicting on X in comparison to y.
    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)
    """
    # Put everything to CPU (works better with NumPy + Matplotlib)
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    # Setup prediction boundaries and grid
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

    # Make features
    X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

    # Make predictions
    model.eval()
    with torch.inference_mode():
        y_logits = model(X_to_pred_on)

    # Test for multi-class or binary and adjust logits to prediction labels
    if len(torch.unique(y)) > 2:
        y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # mutli-class
    else:
        y_pred = torch.round(torch.sigmoid(y_logits))  # binary

    # Reshape preds and plot
    y_pred = y_pred.reshape(xx.shape).detach().numpy()
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())


n_samples = 1000
X, y = make_circles(n_samples,
                    noise=0.03,  # a little bit of noise to the dots
                    random_state=42)  # keep random state so we get the same values

# plt.scatter(x=X[:, 0],
#             y=X[:, 1],
#             c=y,
#             cmap=plt.cm.RdYlBu);

# plt.show()

# Turn data into tensors
# Otherwise this causes issues with computations later on
import torch

X = torch.from_numpy(X).type(torch.float)
y = torch.from_numpy(y).type(torch.float)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Make device agnostic code
# device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# print(device)

# costruisco il modello
model_0 = nn.Sequential(
    nn.Linear(in_features=2, out_features=20),
    nn.ReLU(),
    # nn.Linear(in_features=10, out_features=10),
    # nn.ReLU(),
    nn.Linear(in_features=20, out_features=1)
)

# Create a loss function
# loss_fn = nn.BCELoss() # BCELoss = no sigmoid built-in
loss_fn = nn.BCEWithLogitsLoss()  # BCEWithLogitsLoss = sigmoid built-in

# Create an optimizer
optimizer = torch.optim.SGD(params=model_0.parameters(),
                            lr=0.1)


# Calculate accuracy (a classification metric)
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()  # torch.eq() calculates where two tensors are equal
    acc = (correct / len(y_pred)) * 100
    return acc


# Make predictions with the model
epochs = 1000
# creo delle liste che conterranno i valori di loss per tenerne traccia durante le varie epche
# train_loss_values = []
# test_loss_values = []
# epoch_count = []

# untrained_preds = model_0(X_test.to(device))
# print(f"Length of predictions: {len(untrained_preds)}, Shape: {untrained_preds.shape}")
# print(f"Length of test samples: {len(y_test)}, Shape: {y_test.shape}")
# print(f"\nFirst 10 predictions:\n{untrained_preds[:10]}")
# print(f"\nFirst 10 test labels:\n{y_test[:10]}")


for epoch in range(epochs):
    ### Training

    # 0. imposto la modalità in Training (da fare ad ogni epoca)
    model_0.train()

    # 1. calcolo l'output con i parametri del modello, NB devo gare la "squeeze" percheè va ritdotta di una dimensione
    # quanto l'output del modello ne aggiunge una.
    # I logits sono i valori "grezzi" che, nella caso delle classificazioni BINARIE, NON possono essere comparati
    # con i valori discreti 0/1 delle t_test.
    # I logits quindive dobranno essere convertiti attraverso le funzioni come per la esempio la sigmoing, che
    # non fa altro che ricondurli a valori compresi tra zero e uno che, poi andranno "discretizzati" a 0/1 atttraverso
    # l'uso di funzioni di arrotondamento come per es. la round.
    y_logits = model_0(X_train).squeeze() #

    # pred. logits -> pred. probabilities -> labels 0/1
    y_pred = torch.round(torch.sigmoid(y_logits))


    # 2. calculate loss/accuracy
    # calcolo la loss, da nota che viene utilizzata come loss function la "BCEWithLogitsLoss" che vuole in input
    # dirattamente i logits anzichè i valori predetti, in quanto gli applica la sigmoid e la round in automatico
    # per poi paragonli con le y_train "discrete".
    loss = loss_fn(y_logits, y_train) # nn.BCEWithLogitsLoss()

    # calcololiamo anche la percentuale di accuratezza.
    acc = accuracy_fn(y_true=y_train, y_pred=y_pred)

    # 3. reinizializzo l'optimizer in quanto tende ad accumulare i valori
    optimizer.zero_grad()

    # 4. effettua la back propagation, nella pratica Pytorch tiene traccia dei valori associati alla discesa del gradiente
    #    Quindi calcola la derivata parziale per determinare il minimo della curva dei delta tra valori predetti e valori di test
    loss.backward()

    # 5. ottimizza i parametri (una sola volta) e in base al valore "lr".
    #  NB: cambia quindi i valori dei tensori per cercare di farli avvicinare ai valori ottimali
    optimizer.step()

    ### Testing (in questa fase vengono passati i valori non trainati di test)

    # indico a Pytrch che la fase di training è terminata e che ora devo valutare i parametri e paragonarli con i valori attesi
    model_0.eval()
    with torch.inference_mode(): # disabilito la fase di training

        test_logits = model_0(X_test).squeeze()  #

        # pred. logits -> pred. probabilities -> labels 0/1
        test_pred = torch.round(torch.sigmoid(test_logits))

        # per poi paragonli con le y_train "discrete".
        test_loss = loss_fn(test_logits, y_test)  # nn.BCEWithLogitsLoss()

        # calcololiamo anche la percentuale di accuratezza.
        test_acc = accuracy_fn(y_true=y_test, y_pred=test_pred)

        # Print out what's happening
        if epoch % 10 == 0:
            print(f"Epoch: {epoch} | Train -> Loss: {loss:.5f} , Acc: {acc:.2f}% | Test -> Loss: {test_loss:.5f}%. Acc: {test_acc:.2f}% ")


plt.figure(figsize=(12, 6))
plt.subplot(1,2,1)
plt.title("Train")
plot_decision_boundary(model_0, X_train,y_train)
plt.subplot(1,2,2)
plt.title("Test")
plot_decision_boundary(model_0, X_test,y_test)

#  visualizzo i parametri del modello
# print (model_0.state_dict())