Train a model using data from ApertureDB
Introduction
The following notebook trains a non-pretrained model, using data from ApertureDB.
The following are the key points of the experiment:
- It uses the CIFAR10 dataset from TensorFlow. This is a small dataset, and it suits the purpose of this notebook because the outputs can be easily understood and the training time is small enough.
- It trains a model which is based on ResNet50. This is a 50-layer residual network, which is not as deep as the larger ResNet variants (for example ResNet101 or ResNet152). This makes it suitable for this example. Read more about Residual networks
- It uses tensorboard to visualize the outcome of the iterations of the experiment. Specifically it uses the Tensorboard to
- view the accuracy and loss metrics per epoch.
- confusion matrix per epoch.
- project the classified outcomes in the embedding projector.
Note that this is not a good candidate for comparing the performance of the same dataset loaded from local files versus from ApertureDB, as this is a very small dataset. The purpose of this notebook is to define a blueprint for how to build the necessary tools to train a model using TensorFlow, and to visualize / debug the outcomes using TensorBoard.
# Load the TensorBoard notebook extension
%reload_ext tensorboard
Common imports and definitions.
import numpy as np
import sklearn.metrics
import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time
import warnings
import csv
from PIL import Image
import os
from keras.layers import Dense, GlobalAveragePooling2D
from keras.optimizers import RMSprop
from keras.applications import ResNet50
from keras.models import Sequential
# Helper functions to plot images and show confusion matrix
from helper import show_samples, plot_confusion_matrix,plot_to_image, plot_samples_for_tensorboard
# Reused across the notebook, to make experiments locatable.
log_prefix = "experiments/exp_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Label names for CIFAR10
names = ['plane','auto','bird','cat','deer','dog','frog','horse','ship','truck']
Defining the model
The model trained in the rest of this notebook is built on ResNet50.
It is stacked further with a Dense layer which has 10 outputs, as it's going to be trained on Cifar10 dataset.
# Randomly initialised (weights=None) ResNet50 without its classification head,
# fed with 32x32 RGB inputs. Arguments left out here keep their Keras defaults.
base_model = ResNet50(
    include_top=False,
    weights=None,
    input_tensor=tf.keras.layers.Input(shape=(32, 32, 3)),
)

# ResNet50 features -> global average pooling -> 10-way softmax classifier.
model = Sequential(
    [
        base_model,
        GlobalAveragePooling2D(),
        Dense(10, activation="softmax", name="classifier"),
    ]
)

rmsprop = RMSprop(learning_rate=0.0001)
# The classifier already applies softmax, so the loss consumes probabilities
# rather than logits.
categorical_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
model.compile(loss=categorical_loss, optimizer=rmsprop, metrics=["accuracy"])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
resnet50 (Functional) (None, 1, 1, 2048) 23587712
global_average_pooling2d ( (None, 2048) 0
GlobalAveragePooling2D)
classifier (Dense) (None, 10) 20490
=================================================================
Total params: 23608202 (90.06 MB)
Trainable params: 23555082 (89.86 MB)
Non-trainable params: 53120 (207.50 KB)
_________________________________________________________________
Generate confusion matrix for debuggability
def get_cm_callback(x_val, y_val):
    """Build a Keras callback that logs a confusion matrix image after each epoch.

    Args:
        x_val: Inputs handed to ``model.predict`` at the end of every epoch.
        y_val: True class ids for ``x_val``; they are compared against the
            argmax of the predictions, so they are expected to be integer
            indices into ``names``.

    Returns:
        A ``tf.keras.callbacks.LambdaCallback`` whose ``on_epoch_end`` hook
        writes an ``epoch_confusion_matrix`` image summary under
        ``<log_prefix>/fit/cm``.
    """
    log_dir = log_prefix + "/fit"
    # One writer for the whole run instead of opening a new one every epoch.
    file_writer_cm = tf.summary.create_file_writer(log_dir + "/cm")
    # Pin the label set so the matrix always has one row/column per entry in
    # `names`, even when a class is missing from both y_val and the predictions.
    class_ids = np.arange(len(names))

    def log_confusion_matrix(epoch, logs):
        # Use the (global) model to predict the classes of the sample.
        test_pred_raw = model.predict(x_val)
        test_pred = np.argmax(test_pred_raw, axis=1)

        # Calculate the confusion matrix and render it as an image.
        cm = sklearn.metrics.confusion_matrix(y_val, test_pred, labels=class_ids)
        figure = plot_confusion_matrix(cm, class_names=names)
        cm_image = plot_to_image(figure)

        # Log the confusion matrix as an image summary for this epoch.
        with file_writer_cm.as_default():
            tf.summary.image("epoch_confusion_matrix", cm_image, step=epoch)
        # The writer now outlives the epoch, so push the event to disk explicitly.
        file_writer_cm.flush()

    # Define the per-epoch callback.
    return tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
Using ApertureDB as source
from aperturedb.Images import Images
from aperturedb.Constraints import Constraints
from aperturedb.Subscriptable import Subscriptable
from aperturedb.Utils import create_connector
import logging
import tensorflow as tf
class DS(Subscriptable):
    """Batched, index-addressable view over labelled images stored in ApertureDB.

    ``ds[i]`` (served by ``getitem``) yields the i-th batch as a pair
    ``(images, one_hot_labels)``, each a list of ``batch_size`` entries.
    Only full batches are exposed: a trailing partial batch is dropped.
    """

    def __init__(self, stage, batch_size, limit=50000) -> None:
        """Query the images of one split and prepare batch bookkeeping.

        Args:
            stage: ``"train"`` selects images whose ``train`` property is
                True; any other value selects those where it is False.
            batch_size: Number of images per batch.
            limit: Maximum number of images requested from the database.
        """
        super().__init__()
        const = Constraints().equal("train", stage == "train")
        self.images = Images(db=create_connector())
        self.images.search(constraints=const, limit=limit)
        self.label_ids = self.images.get_properties(["label"])
        print(f"{len(self.images.images_ids)} images found for stage {stage}")
        self.batch_size = batch_size
        # Integer division: every batch served has exactly batch_size items.
        self.batches = len(self.images.images_ids) // self.batch_size

    def getitem(self, subscript):
        """Return batch number ``subscript`` as ``(images, one_hot_labels)``.

        Images are converted to float32 tensors with
        ``tf.image.convert_image_dtype``; labels are 10-element one-hot lists.
        """
        batch_images = []
        batch_labels = []
        first = subscript * self.batch_size
        for index in range(first, first + self.batch_size):
            img = self.images.get_np_image_by_index(index)
            batch_images.append(tf.image.convert_image_dtype(img, tf.float32))
            label = int(self.label_ids[self.images.images_ids[index]]["label"])
            batch_labels.append(
                [1 if class_id == label else 0 for class_id in range(10)])
        return batch_images, batch_labels

    def __len__(self):
        """Return the number of full batches available."""
        return self.batches
def tf_gen(stage, batch_size=64):
    """Wrap a ``DS`` split in a ``tf.data.Dataset``.

    Args:
        stage: Split name forwarded to ``DS`` (e.g. ``"train"`` or ``"test"``).
        batch_size: Number of images per yielded element.

    Returns:
        A dataset whose elements are ``(images, labels)`` pairs with shapes
        ``(batch_size, 32, 32, 3)`` float32 and ``(batch_size, 10)`` int64.
        A fresh ``DS`` is constructed each time the dataset is iterated.
    """
    image_spec = tf.TensorSpec((batch_size, 32, 32, 3), tf.float32)
    label_spec = tf.TensorSpec((batch_size, 10), tf.int64)

    def make_source():
        return DS(stage, batch_size)

    return tf.data.Dataset.from_generator(
        make_source,
        output_signature=(image_spec, label_spec),
    )
# Direct, index-addressable handle on the training split; later cells use
# ds_source[i] to pull individual batches for plotting and embeddings.
ds_source = DS("train", batch_size=64)
# Number of full 64-image batches in the training split.
len(ds_source)
# tf.data pipelines handed to model.fit: "train" feeds training and "test"
# (images whose `train` property is False) feeds validation.
ds_train = tf_gen("train", batch_size=64)
ds_val = tf_gen("test", batch_size=64)
50000 images found for stage train
Plot sample images from ApertureDB
# Build the first training batch once; every ds_source[...] lookup assembles
# the whole batch again through DS.getitem.
sample_images, sample_one_hot = ds_source[0]
data = np.asanyarray(sample_images)
# One-hot rows -> integer class ids (the position of the single 1 in each row).
labels = np.argmax(np.asanyarray(sample_one_hot), axis=1)
show_samples(data, labels, names)
plot_samples_for_tensorboard(log_prefix, data)
2023-10-31 19:24:20.837692: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
Model training
# Scalars only: no histograms, graph, weight images, or profiling are written.
log_dir = os.path.join(log_prefix, "fit_aperturedb")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=False,
    write_images=False,
    update_freq="epoch",
    profile_batch=0,
)

# NOTE(review): `data` / `labels` come from ds_source, i.e. the first batch of
# the *training* split, so the per-epoch confusion matrix is computed on
# training samples rather than on ds_val. Confirm that this is intended.
cm_callback = get_cm_callback(data, labels)

start = time.time()
# batch_size is deliberately not passed: ds_train / ds_val already yield
# batches, and Keras asks that batch_size be left unset for dataset inputs.
hist = model.fit(
    ds_train,
    epochs=5,
    validation_data=ds_val,
    callbacks=[tensorboard_callback, cm_callback],
)
print(f"training took: {time.time() - start}s")
Epoch 1/5
50000 images found for stage train
2023-10-31 19:24:54.116512: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8905
2023-10-31 19:24:54.161499: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-10-31 19:24:54.926364: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f835b0004c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-31 19:24:54.926377: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2023-10-31 19:24:54.928640: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-10-31 19:24:54.961594: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
780/Unknown - 48s 45ms/step - loss: 2.3075 - accuracy: 0.2429
2023-10-31 19:25:34.921086: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12951968619303860940
10000 images found for stage test
2/2 [==============================] - 1s 4ms/step
781/781 [==============================] - 53s 52ms/step - loss: 2.3069 - accuracy: 0.2430 - val_loss: 3.0636 - val_accuracy: 0.2938
Epoch 2/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 1.5646 - accuracy: 0.435410000 images found for stage test
2/2 [==============================] - 0s 5ms/step
781/781 [==============================] - 42s 49ms/step - loss: 1.5644 - accuracy: 0.4354 - val_loss: 1.5898 - val_accuracy: 0.4585
Epoch 3/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 1.2961 - accuracy: 0.536510000 images found for stage test
2/2 [==============================] - 0s 4ms/step
781/781 [==============================] - 42s 49ms/step - loss: 1.2959 - accuracy: 0.5366 - val_loss: 1.4096 - val_accuracy: 0.5062
Epoch 4/5
50000 images found for stage train
781/781 [==============================] - ETA: 0s - loss: 1.0716 - accuracy: 0.621810000 images found for stage test
2/2 [==============================] - 0s 3ms/step
781/781 [==============================] - 42s 50ms/step - loss: 1.0716 - accuracy: 0.6218 - val_loss: 1.3959 - val_accuracy: 0.5165
Epoch 5/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 0.8420 - accuracy: 0.705110000 images found for stage test
2/2 [==============================] - 0s 4ms/step
781/781 [==============================] - 42s 50ms/step - loss: 0.8419 - accuracy: 0.7052 - val_loss: 1.8618 - val_accuracy: 0.4493
training took: 220.64626812934875s
Embedding projector
log_dir = os.path.join(log_prefix, "aperturedb_embeddings")
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(log_dir, exist_ok=True)

sample_count = 256
# The vectors fed to the projector are the outputs of the model's last layer.
embeddings = tf.keras.models.Model(inputs=model.input, outputs=model.layers[-1].output)

# Populate the embedding space with the vectors and their metadata.
images_pil = []
images_embeddings = []
labels = []
for batch_idx in range(sample_count // ds_source.batch_size):
    # Build each batch once; a ds_source[...] lookup assembles the whole batch.
    batch_images, batch_one_hot = ds_source[batch_idx]
    # One-hot rows -> integer class ids.
    class_ids = np.argmax(np.asanyarray(batch_one_hot), axis=1)
    # Predict the batch in a single call instead of once per image.
    batch_embeddings = embeddings.predict(np.asanyarray(batch_images))

    images_pil.extend(
        tf.keras.preprocessing.image.array_to_img(img) for img in batch_images)
    images_embeddings.extend(batch_embeddings)
    labels.extend(names[class_id] for class_id in class_ids)
len(labels)
1/1 [==============================] - 1s 789ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 13ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 9ms/step
256
Simpler way of writing the same info using torch utils.
#https://github.com/tensorflow/tensorboard/issues/2471#issuecomment-636431260
from torch.utils.tensorboard import SummaryWriter
import torchvision
# SummaryWriter is a context manager: leaving the block closes the writer
# even if one of the logging calls raises.
with SummaryWriter(log_dir) as writer:
    # One embedding row per sample, tagged with its class name.
    writer.add_embedding(np.array(images_embeddings), metadata=labels)
    # Tile all sampled images into a single grid image.
    image_tensors = [
        torchvision.transforms.functional.pil_to_tensor(img) for img in images_pil
    ]
    writer.add_image("images", torchvision.utils.make_grid(image_tensors), 0)