
Transfer Learning And Hyperparameter Tuning Example

In the following example, we’re going to demonstrate how to implement transfer learning and hyperparameter tuning. We’ll use Xception, a pre-trained model, as the base for an image classifier.

We’ll train it on a dataset containing 10 different types of wild cats.

This approach is also light enough to train on older GPUs, such as my GeForce GTX 1050 Ti.
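One optional tweak that can help on such cards is enabling GPU memory growth, so TensorFlow allocates VRAM gradually instead of reserving all of it up front. Here is a minimal sketch of that setting; it isn’t required for the rest of the example.

import tensorflow as tf

# allocate GPU memory as needed instead of reserving it all at once,
# which helps on cards with little VRAM
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)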

About the dataset

The dataset consists of 2439 images, already split into subdirectories for training, validation and testing. All images are the same size: 224 pixels wide by 224 pixels high, with 3 color channels.

In other words, each sample is a 3-dimensional array of shape (224, 224, 3), since we’re dealing with color images.
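Once the dataset is downloaded (we’ll do that below), we can quickly verify these dimensions ourselves. A minimal sketch, grabbing an arbitrary image from the training split:

from glob import glob
import matplotlib.pyplot as plt

# read any one training image and print its dimensions (height, width, channels)
sample_path = glob('datasets/wild_cats/train/*/*')[0]
print(plt.imread(sample_path).shape)  # expected: (224, 224, 3)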

Transfer learning part of the example

First of all, we’ll need to import and preprocess our images so that our model can accept them.

Before we do that, we need to import all the necessary libraries.

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import keras_tuner as kt
import matplotlib.pyplot as plt

from glob import glob
from tqdm import tqdm
from kaggle.api.kaggle_api_extended import KaggleApi

We’re also going to download the dataset using the Kaggle API, which lets us do everything from this one script. For that, we need to authenticate with the API and call the download function from the kaggle library.

# authenticate connection with Kaggle API
api = KaggleApi()
api.authenticate()

# download dataset from https://www.kaggle.com/datasets/gpiosenka/cats-in-the-wild-image-classification
api.dataset_download_files(
    'gpiosenka/cats-in-the-wild-image-classification',
    path='datasets/wild_cats',
    unzip=True
)

Next, we need to load the dataset and store it in variables for training, validation and testing. Since this is a classification task, we also need labels, which tell the algorithm what class each sample belongs to.

# get paths to dataset partitions
train_dir = 'datasets/wild_cats/train'
test_dir = 'datasets/wild_cats/test'
valid_dir = 'datasets/wild_cats/valid'

# get total number and names of the classes
class_names = sorted(os.listdir(train_dir))  # sort for a consistent class order across runs
n_classes = len(class_names)

# set parameters for our model
BATCH_SIZE = 32
IMG_SIZE = 224
AUTOTUNE = tf.data.AUTOTUNE
LEARNING_RATE = 1e-3

# set seed for replicating same results
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# set function to import and preprocess image
def load_image(image_path):

    assert os.path.exists(image_path), f'Invalid image path: {image_path}'
    image = plt.imread(image_path)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32)
    image = image / 255.0

    return image

# set function for loading the dataset from folder
def load_dataset(root_path, class_names):
    n_samples = sum([len(os.listdir(os.path.join(root_path, name))) for name in class_names])
    images = np.empty(shape=(n_samples, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
    labels = np.empty(shape=(n_samples, 1), dtype=np.int32)

    n_image = 0
    for class_name in tqdm(class_names, desc='Loading'):
        class_path = os.path.join(root_path, class_name)
        for file_path in glob(os.path.join(class_path, '*')):
            image = load_image(file_path)
            label = class_names.index(class_name)

            images[n_image] = image
            labels[n_image] = label

            n_image += 1
    
    indices = np.random.permutation(n_samples)
    images = images[indices]
    labels = labels[indices]

    return images, labels

# load images and their labels into variables
X_train, y_train = load_dataset(root_path=train_dir, class_names=class_names)
X_valid, y_valid = load_dataset(root_path=valid_dir, class_names=class_names)
X_test, y_test = load_dataset(root_path=test_dir, class_names=class_names)
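As a quick sanity check, we can print the shapes of the arrays we just loaded; they should line up with the dataset description above:

# quick sanity check of the loaded data
print('Classes:', n_classes, class_names)
print('Train:', X_train.shape, y_train.shape)
print('Valid:', X_valid.shape, y_valid.shape)
print('Test: ', X_test.shape, y_test.shape)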

Okay, now we’re ready to import the pre-trained Xception model and extend it with additional layers. This will serve as our baseline model, which gives us a reference loss and accuracy.

We’ll use these numbers later to compare against the model we get after hyperparameter tuning.

# get pre-trained model and freeze its layers
xception = tf.keras.applications.Xception(input_shape=(IMG_SIZE, IMG_SIZE, 3), weights='imagenet', include_top=False)
xception.trainable = False

# set up a baseline model, compile and train it
xbaseline = tf.keras.Sequential([
    xception,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(n_classes, activation='softmax')
])

xbaseline.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy']
)

xbaseline.fit(
    X_train, y_train, 
    validation_data=(X_valid, y_valid), 
    epochs=50, 
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint("XceptionBaseline.h5", save_best_only=True)
    ],
    batch_size=BATCH_SIZE
)

# print out its results
xtest_loss, xtest_acc = xbaseline.evaluate(X_test, y_test)
print(f"Xception Baseline Testing Loss     : {xtest_loss}.")
print(f"Xception Baseline Testing Accuracy : {xtest_acc}.")

Following are the results we get after we run the script at this point.

(Figure: Xception baseline model results after transfer learning.)

Since the results of the baseline model are already really good, we might not need hyperparameter tuning at all. But we’ll do it anyway to demonstrate how it’s done.

We’re going to use the random search method to find the best combination of hyperparameters. For this example, we’ll search over the number of fully connected layers, the number of units in them, and the dropout rate.

# set function for creating a model from available hyperparameters
def build_model(hp):
    
    # Define all hyperparameters to search over
    n_layers = hp.Choice('n_layers', [0, 2, 4])
    dropout_rate = hp.Choice('rate', [0.2, 0.4, 0.5, 0.7])
    n_units = hp.Choice('units', [64, 128, 256, 512])
    
    # Model architecture: frozen Xception base plus global average pooling
    model = tf.keras.Sequential([
        xception,
        tf.keras.layers.GlobalAveragePooling2D(),
    ])
    
    # Add hidden/top layers 
    for _ in range(n_layers):
        model.add(tf.keras.layers.Dense(n_units, activation='relu', kernel_initializer='he_normal'))
    
    # Add Dropout Layer
    model.add(tf.keras.layers.Dropout(dropout_rate))
    
    # Output Layer
    model.add(tf.keras.layers.Dense(n_classes, activation='softmax'))
    
    # Compile the model
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer = tf.keras.optimizers.Adam(LEARNING_RATE),
        metrics = ['accuracy']
    )
    
    # Return model
    return model
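
# (optional sanity check) build one model with default hyperparameter choices
# to confirm the builder works before handing it to the tuner
build_model(kt.HyperParameters()).summary()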

# Initialize Random Searcher
random_searcher = kt.RandomSearch(
    hypermodel=build_model, 
    objective='val_loss', 
    max_trials=10, 
    seed=42, 
    project_name="XceptionSearch", 
    loss='sparse_categorical_crossentropy')

# Start Searching
random_searcher.search(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    epochs = 10,
    batch_size = BATCH_SIZE
)
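
# (optional) print a summary of the completed trials to see how each
# hyperparameter combination performed
random_searcher.results_summary()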

best_xception = build_model(random_searcher.get_best_hyperparameters(num_trials=1)[0])

# Model Architecture
best_xception.summary()
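
# (optional) inspect the hyperparameter values the random search settled on
print(random_searcher.get_best_hyperparameters(num_trials=1)[0].values)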

# Compile Model
best_xception.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE*0.1),
    metrics=['accuracy']
)

# Model Training
best_xception_history = best_xception.fit(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    epochs = 50,
    batch_size = BATCH_SIZE*2,
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint("BestXception.h5", save_best_only=True)
    ]
)

best_test_loss, best_test_acc = best_xception.evaluate(X_test, y_test)
print(f"Test Loss after Tuning     : {best_test_loss} | {xtest_loss}")
print(f"Test Accuracy after Tuning : {best_test_acc}  | {xtest_acc}")

And following are the results after the algorithm finishes tuning.
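
Since the fit call returned a History object in best_xception_history, we can also plot the training curves to see how the tuned model converged. A minimal sketch:

# plot training and validation accuracy of the tuned model over the epochs
plt.plot(best_xception_history.history['accuracy'], label='train accuracy')
plt.plot(best_xception_history.history['val_accuracy'], label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()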

Conclusion

To conclude, in this example we downloaded and preprocessed an image dataset, and successfully trained a classifier using Xception as its base.

I hope this example helped you gain a better understanding of how to implement transfer learning and hyperparameter tuning in practice.
