AI vs. Real Artwork Vision Transformer Classifier

In this code, I built a project aimed at classifying images to distinguish between AI-generated and real artwork using a Vision Transformer (ViT) model.

Here’s a breakdown of everything I did:

Model Selection and Configuration: I started by choosing the Vision Transformer model (vit_base_patch16_224_in21k) from the timm library: a ViT-Base model that takes 224×224 inputs, splits them into 16×16 patches, and is pretrained on ImageNet-21k. Vision Transformers perform strongly on image classification and are good at capturing global patterns in an image, which made this a natural backbone for the task.
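As a minimal sketch, the backbone can be created with timm roughly like this (the variable names are illustrative, not the exact code in this repo):

```python
import timm
import torch

# ViT-Base, 16x16 patches, 224x224 input, ImageNet-21k pretrained weights.
backbone = timm.create_model(
    "vit_base_patch16_224_in21k",
    pretrained=True,
    num_classes=0,  # drop the default head; a custom head is attached later
)

# Quick shape check: ViT-Base produces a 768-dimensional feature vector.
with torch.no_grad():
    features = backbone(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 768])
```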

Data Preparation: The images were stored in two directories, one for AI-generated artwork and one for real artwork. I combined the file paths and labeled them (0 for AI-generated, 1 for real). I then used StratifiedKFold to split the data into three folds, which keeps the class distribution balanced in every fold and gives a more reliable estimate of how well the model generalizes than a single train/validation split.
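A hedged sketch of how the paths and labels might be combined and split; the directory names and column names here are assumptions:

```python
from pathlib import Path

import pandas as pd
from sklearn.model_selection import StratifiedKFold

ai_paths = list(Path("data/ai_art").glob("*.jpg"))      # label 0: AI-generated
real_paths = list(Path("data/real_art").glob("*.jpg"))  # label 1: real artwork

df = pd.DataFrame({
    "path": ai_paths + real_paths,
    "label": [0] * len(ai_paths) + [1] * len(real_paths),
})

# Assign each image to one of three stratified folds.
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
df["fold"] = -1
for fold, (_, valid_idx) in enumerate(skf.split(df, df["label"])):
    df.loc[valid_idx, "fold"] = fold
```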

Augmentation and Dataset Class: To enhance the model’s robustness, I applied data augmentation using the albumentations library. Augmentation techniques like flipping, scaling, and rotation were used to simulate different variations of the images. I created a custom Dataset class (ArtDataset) to load the images and apply these augmentations during training.
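The pipeline and dataset could look roughly like the sketch below; the exact transforms and the internals of ArtDataset in this repo may differ:

```python
import albumentations as A
import cv2
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

# Flip, shift/scale/rotate, then normalize and convert to a tensor.
train_transforms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.Normalize(),
    ToTensorV2(),
])

class ArtDataset(Dataset):
    def __init__(self, paths, labels, transforms=None):
        self.paths = paths
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Read the image and convert from OpenCV's BGR to RGB.
        image = cv2.imread(str(self.paths[idx]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label
```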

Model Architecture: I designed the VITModel class, which wraps the pre-trained Vision Transformer and replaces its final layer with a head that outputs a single value (a logit) indicating how likely the image is to be real rather than AI-generated.

Training Process: I implemented a Trainer class to handle training and validation. The model was trained with BCEWithLogitsLoss, which is well suited to binary classification. In each epoch the model was trained on the training split and then evaluated on the validation split using the ROC-AUC score, a standard metric for binary classifiers. After each epoch the model’s state was saved so the trained weights could be reused later.
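A condensed sketch of both pieces follows. The real VITModel and Trainer classes have more structure (schedulers, checkpoint naming, logging, and so on), so treat this as an outline under those assumptions rather than the actual implementation:

```python
import timm
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score

class VITModel(nn.Module):
    def __init__(self, model_name="vit_base_patch16_224_in21k", pretrained=True):
        super().__init__()
        # timm attaches a linear head with a single output when num_classes=1.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=1)

    def forward(self, x):
        return self.backbone(x).squeeze(-1)  # one logit per image

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VITModel().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def run_epoch(train_loader, valid_loader, fold, epoch):
    # Train on the training split.
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Validate and compute ROC-AUC on the held-out fold.
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for images, labels in valid_loader:
            logits = model(images.to(device))
            all_probs.extend(torch.sigmoid(logits).cpu().tolist())
            all_labels.extend(labels.tolist())
    auc = roc_auc_score(all_labels, all_probs)

    # Save the weights so each fold's model can be reused for ensembling.
    torch.save(model.state_dict(), f"vit_fold{fold}_epoch{epoch}.pth")
    return auc
```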

Ensemble and Prediction: In ensemble.ipynb, I loaded the model trained on each fold and used all of them to predict on new images. Averaging the per-fold predictions (ensembling) typically gives more robust results than any single fold’s model. I also included functions to display predictions alongside the original image, which is helpful for visually checking the results.
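A sketch of fold-averaged inference, reusing the VITModel wrapper from the training sketch above; the checkpoint file names are assumptions:

```python
import torch

checkpoints = ["vit_fold0.pth", "vit_fold1.pth", "vit_fold2.pth"]

def predict_ensemble(image_tensor, device="cpu"):
    """Average the sigmoid outputs of the per-fold models for one image tensor."""
    probs = []
    for ckpt in checkpoints:
        model = VITModel()  # same wrapper class used during training
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device).eval()
        with torch.no_grad():
            logit = model(image_tensor.unsqueeze(0).to(device))
            probs.append(torch.sigmoid(logit).item())
    # Closer to 1.0 -> "real artwork", closer to 0.0 -> "AI-generated".
    return sum(probs) / len(probs)
```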

Deployment via Streamlit: Finally, I set up a Streamlit application (app.py) so users can upload an image and see a prediction in real time. The app loads the trained models, preprocesses the uploaded image, and displays whether it is predicted to be AI-generated or real.
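A rough sketch of what app.py does: load the fold models once, preprocess the upload, and show the averaged prediction. It assumes the predict_ensemble helper from the sketch above; the widget text and threshold are illustrative, not the exact repo code:

```python
import albumentations as A
import numpy as np
import streamlit as st
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Same preprocessing as validation: resize, normalize, convert to tensor.
transform = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])

st.title("AI vs. Real Artwork Classifier")
uploaded = st.file_uploader("Upload an artwork image", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    image = Image.open(uploaded).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)

    tensor = transform(image=np.array(image))["image"]
    prob_real = predict_ensemble(tensor)  # fold-averaged probability from the sketch above
    label = "Real artwork" if prob_real >= 0.5 else "AI-generated"
    st.write(f"Prediction: {label} (probability of being real: {prob_real:.2f})")
```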