from pathlib import Path

import cv2 as cv
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from model import CalligraphyClassifyModel


# Global configuration
DATASET_DIR = Path("dataset")  # root directory holding the "train" and "test" splits
MODEL_PT_FILE = Path("resource/model.pt")  # destination of the trained weights

NUM_EPOCHS = 10
BATCH_SIZE = 32

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without a usable CUDA device instead of failing at tensor creation.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CalligraphyDataset(Dataset):
    def __init__(self, base_dir: str | Path):
        base_dir = Path(base_dir)

        features, labels = [], []
        for label, label_name in enumerate(CalligraphyClassifyModel.LABEL_SET):
            for image_file in (base_dir / label_name).iterdir():
                features.append(cv.imread(image_file, cv.IMREAD_GRAYSCALE))
                labels.append(label)

        self.features = torch.tensor(np.array(features), dtype=torch.float, device=DEVICE)
        self.labels = torch.tensor(labels, dtype=torch.long, device=DEVICE)

    def __getitem__(self, idx: int):
        return self.features[idx], self.labels[idx]

    def __len__(self):
        return len(self.labels)


class ModelTrainer:
    def __init__(self):
        self.train_dataset = CalligraphyDataset(DATASET_DIR / "train")
        self.test_dataset = CalligraphyDataset(DATASET_DIR / "test")

        self.train_loader = DataLoader(self.train_dataset, BATCH_SIZE, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, BATCH_SIZE)

        self.model = CalligraphyClassifyModel().to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def train(self, num_epochs: int = 100):
        self.model.train()

        for epoch in range(num_epochs):
            num_correct_prediction = 0
            num_total_prediction = 0

            train_loss = 0.0
            for features, labels in self.train_loader:
                self.optimizer.zero_grad()
                output = self.model(features)
                loss = self.loss_fn(output, labels)
                loss.backward()
                self.optimizer.step()

                predicted_labels = self.model.get_labels(output)

                train_loss += loss.item() * len(labels)
                num_correct_prediction += torch.count_nonzero(predicted_labels == labels)
                num_total_prediction += len(labels)

            train_accuracy = num_correct_prediction / num_total_prediction
            print(
                f"[{epoch + 1}/{num_epochs}] loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f}"
            )

    def test(self):
        self.model.eval()

        test_loss = 0.0
        num_correct_prediction = 0
        num_total_prediction = 0

        with torch.no_grad():
            for features, labels in self.test_loader:
                output = self.model(features)
                loss = self.loss_fn(output, labels)

                predicted_labels = self.model.get_labels(output)

                test_loss += loss.item() * len(labels)
                num_correct_prediction += torch.count_nonzero(predicted_labels == labels)
                num_total_prediction += len(labels)

        test_accuracy = num_correct_prediction / num_total_prediction
        print(f"[test result]: loss: {test_loss:.4f}, test_accuracy: {test_accuracy:.4f}")

    def save_model(self, file_name: str | Path):
        torch.save(self.model.state_dict(), file_name)


def main():
    """Train the classifier, evaluate it on the test split and save its weights."""
    runner = ModelTrainer()
    runner.train(num_epochs=NUM_EPOCHS)
    runner.test()
    runner.save_model(file_name=MODEL_PT_FILE)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
