Skip to main content

Train a model using data from ApertureDB

Introduction

This notebook trains a model from scratch (no pretrained weights) using data stored in ApertureDB.

The following are the key points of the experiment:

  1. It uses the CIFAR-10 dataset from TensorFlow. This is a small dataset, which suits the purpose of this notebook: the outputs are easy to understand and the training time is short.
  2. It trains a model based on ResNet50, a 50-layer residual network that is shallower than the deeper ResNet variants (such as ResNet101 and ResNet152). This makes it suitable for this example. Read more about Residual networks
  3. It uses TensorBoard to visualize the outcome of each iteration of the experiment. Specifically, it uses TensorBoard to
  • view the accuracy and loss metrics per epoch.
  • view the confusion matrix per epoch.
  • project the classified outcomes in the embedding projector.

Note that this is not a good candidate for comparing the performance of reading the same dataset from local files versus from ApertureDB, because the dataset is very small. The purpose of this notebook is to define a blueprint for building the tools needed to train a model with TensorFlow, and to visualize and debug the outcomes with TensorBoard.

# Load the TensorBoard notebook extension.
# This is an IPython line magic, not plain Python: %reload_ext loads the
# extension, or reloads it if it was already loaded in this kernel.
%reload_ext tensorboard

Common imports and definitions.

import numpy as np
import sklearn.metrics
import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import time
import warnings



import csv
from PIL import Image
import os

from keras.layers import Dense, GlobalAveragePooling2D
from keras.optimizers import RMSprop
from keras.applications import ResNet50
from keras.models import Sequential

# Helper functions to plot images and show confusion matrix
from helper import show_samples, plot_confusion_matrix,plot_to_image, plot_samples_for_tensorboard


# Timestamped prefix shared by every log directory written in this notebook,
# so that all artifacts of one run sit under a single experiment folder.
log_prefix = f"experiments/exp_{datetime.datetime.now():%Y%m%d-%H%M%S}"

# CIFAR-10 class names, indexed by integer class label.
names = "plane auto bird cat deer dog frog horse ship truck".split()

Defining the model

The model trained in the rest of this notebook is based on ResNet50.

It is followed by a global average pooling layer and a Dense layer with 10 outputs, since the model is trained on the CIFAR-10 dataset.


# ResNet50 backbone trained from scratch (weights=None) on 32x32 RGB inputs,
# without its built-in classification head (include_top=False).
base_model = ResNet50(
    include_top=False,
    weights=None,
    input_tensor=tf.keras.layers.Input(shape=(32, 32, 3)),
    input_shape=None,
    pooling=None,
    classes=1000,
)

# Backbone -> global average pooling -> 10-way softmax classifier
# (one output unit per CIFAR-10 class).
model = Sequential(
    [
        base_model,
        GlobalAveragePooling2D(),
        Dense(10, activation="softmax", name="classifier"),
    ]
)

# The classifier already applies softmax, so the loss receives probabilities
# rather than logits (from_logits=False).
rmsprop = RMSprop(learning_rate=0.0001)
model.compile(
    optimizer=rmsprop,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
resnet50 (Functional) (None, 1, 1, 2048) 23587712

global_average_pooling2d ( (None, 2048) 0
GlobalAveragePooling2D)

classifier (Dense) (None, 10) 20490

=================================================================
Total params: 23608202 (90.06 MB)
Trainable params: 23555082 (89.86 MB)
Non-trainable params: 53120 (207.50 KB)
_________________________________________________________________

Generate a confusion matrix for debuggability

def get_cm_callback(x_val, y_val, class_names=None):
    """Build a Keras callback that logs a confusion-matrix image after each epoch.

    At the end of every epoch the global ``model`` predicts ``x_val``; the
    arg-max predictions are compared to ``y_val`` and the resulting matrix is
    rendered (via the ``helper`` plotting functions) and written as the image
    summary ``epoch_confusion_matrix`` under ``<log_prefix>/fit/cm``.

    Args:
        x_val: Inputs to classify, in any form accepted by ``model.predict``.
        y_val: Integer class ids for ``x_val`` (not one-hot).
        class_names: Axis labels for the matrix; defaults to the module-level
            ``names`` list. Its length fixes the matrix size.

    Returns:
        A ``tf.keras.callbacks.LambdaCallback`` to pass to ``model.fit``.
    """
    if class_names is None:
        class_names = names
    log_dir = os.path.join(log_prefix, "fit")

    # Create the writer once for the lifetime of the callback instead of a new,
    # never-closed writer on every epoch.
    file_writer_cm = tf.summary.create_file_writer(os.path.join(log_dir, "cm"))

    # Pin the label set so the matrix is always len(class_names) square and its
    # rows/columns line up with class_names, even when some classes are absent
    # from y_val or from the predictions.
    label_ids = list(range(len(class_names)))

    def log_confusion_matrix(epoch, logs):
        # Predict quietly: this runs every epoch and would otherwise print a
        # progress bar in the middle of the training log.
        test_pred_raw = model.predict(x_val, verbose=0)
        test_pred = np.argmax(test_pred_raw, axis=1)

        cm = sklearn.metrics.confusion_matrix(y_val, test_pred, labels=label_ids)
        figure = plot_confusion_matrix(cm, class_names=class_names)
        cm_image = plot_to_image(figure)

        with file_writer_cm.as_default():
            tf.summary.image("epoch_confusion_matrix", cm_image, step=epoch)
        # Make the image visible to TensorBoard without waiting for the
        # writer's periodic flush.
        file_writer_cm.flush()

    return tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)

Using ApertureDB as source

from aperturedb.Images import  Images
from aperturedb.Constraints import Constraints
from aperturedb.Subscriptable import Subscriptable
from aperturedb.Utils import create_connector

import logging

import tensorflow as tf


class DS(Subscriptable):
    """Batched, indexable view over the images of one split stored in ApertureDB.

    Each item is one batch: a pair ``(x, y)`` of Python lists with exactly
    ``batch_size`` entries. ``x`` holds the images converted with
    ``tf.image.convert_image_dtype(..., tf.float32)``, and ``y`` holds
    10-element one-hot label lists. A trailing partial batch is dropped.
    """

    def __init__(self, stage, batch_size, limit=50000) -> None:
        """Query ApertureDB for the images of one split.

        Args:
            stage: ``"train"`` selects images whose ``train`` property equals
                True; any other value (e.g. ``"test"``) selects those where it
                equals False.
            batch_size: Number of images per batch.
            limit: Maximum number of images returned by the search.
        """
        super().__init__()
        const = Constraints().equal("train", stage == "train")
        self.images = Images(db=create_connector())
        self.images.search(constraints=const, limit=limit)
        # Properties of the found images (only "label" is requested); indexed
        # below by entries of self.images.images_ids.
        self.label_ids = self.images.get_properties(["label"])
        print(f"{len(self.images.images_ids)} images found for stage {stage}")
        self.batch_size = batch_size
        # Floor division: the last partial batch is dropped so every batch has
        # exactly batch_size entries.
        self.batches = len(self.images.images_ids) // self.batch_size

    def getitem(self, subscript):
        """Return batch number ``subscript`` as ``(images, one_hot_labels)``.

        Presumably invoked by ``Subscriptable`` for ``self[subscript]`` — confirm
        in the aperturedb package.
        """
        start = subscript * self.batch_size
        x = []
        y = []
        for index in range(start, start + self.batch_size):
            img = self.images.get_np_image_by_index(index)
            x.append(tf.image.convert_image_dtype(img, tf.float32))
            label = int(self.label_ids[self.images.images_ids[index]]["label"])
            # One-hot encode over the 10 CIFAR-10 classes.
            y.append([int(class_id == label) for class_id in range(10)])
        return x, y

    def __len__(self):
        """Return the number of full batches."""
        return self.batches

def tf_gen(stage, batch_size=64):
    """Wrap ``DS(stage, batch_size)`` as a ``tf.data.Dataset`` of fixed-size batches.

    The factory handed to ``from_generator`` builds a fresh ``DS`` whenever the
    dataset is iterated. The training log shows the "images found" line once
    per epoch, so every epoch re-runs the ApertureDB query.
    """
    images_spec = tf.TensorSpec((batch_size, 32, 32, 3), tf.float32)
    labels_spec = tf.TensorSpec((batch_size, 10), tf.int64)
    return tf.data.Dataset.from_generator(
        lambda: DS(stage, batch_size),
        output_signature=(images_spec, labels_spec),
    )

# Direct (non-tf.data) handle on the training split, used below to plot
# samples and to build the embedding projector data.
ds_source = DS("train", batch_size=64)
len(ds_source)

# tf.data pipelines: "train" split for fitting, "test" split for validation.
ds_train, ds_val = tf_gen("train", batch_size=64), tf_gen("test", batch_size=64)
50000 images found for stage train

Plot sample images from ApertureDB

# Fetch the first training batch once; every ds_source[i] access rebuilds the
# whole batch through DS.getitem.
batch_images, batch_one_hot = ds_source[0]
data = np.asanyarray(batch_images)
# Decode the one-hot rows back to integer class ids (1-D even for a batch of 1).
labels = np.argmax(np.asanyarray(batch_one_hot), axis=1)
show_samples(data, labels, names)

png

# Hand the sample batch and this run's log prefix to the local helper.
# NOTE(review): presumably this writes an image summary under log_prefix;
# confirm in helper.py.
plot_samples_for_tensorboard(log_prefix, data)
2023-10-31 19:24:20.837692: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory

Model training


log_dir = os.path.join(log_prefix, "fit_aperturedb")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=False,
    write_images=False,
    update_freq="epoch",
    profile_batch=0,
)

# Compute the per-epoch confusion matrix on a held-out batch from the
# validation split, not on the training samples plotted above.
x_cm, y_cm_one_hot = next(iter(ds_val))
y_cm = np.argmax(y_cm_one_hot.numpy(), axis=1)

start = time.perf_counter()
# ds_train already yields batches of 64 (see tf_gen). Keras does not accept
# batch_size together with a tf.data.Dataset, so it is not passed here.
hist = model.fit(
    ds_train,
    epochs=5,
    validation_data=ds_val,
    callbacks=[tensorboard_callback, get_cm_callback(x_cm, y_cm)],
)
print(f"training took: {time.perf_counter() - start}s")
Epoch 1/5
50000 images found for stage train


2023-10-31 19:24:54.116512: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8905
2023-10-31 19:24:54.161499: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-10-31 19:24:54.926364: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f835b0004c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-10-31 19:24:54.926377: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2023-10-31 19:24:54.928640: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-10-31 19:24:54.961594: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.


780/Unknown - 48s 45ms/step - loss: 2.3075 - accuracy: 0.2429

2023-10-31 19:25:34.921086: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12951968619303860940


10000 images found for stage test
2/2 [==============================] - 1s 4ms/step
781/781 [==============================] - 53s 52ms/step - loss: 2.3069 - accuracy: 0.2430 - val_loss: 3.0636 - val_accuracy: 0.2938
Epoch 2/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 1.5646 - accuracy: 0.435410000 images found for stage test
2/2 [==============================] - 0s 5ms/step
781/781 [==============================] - 42s 49ms/step - loss: 1.5644 - accuracy: 0.4354 - val_loss: 1.5898 - val_accuracy: 0.4585
Epoch 3/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 1.2961 - accuracy: 0.536510000 images found for stage test
2/2 [==============================] - 0s 4ms/step
781/781 [==============================] - 42s 49ms/step - loss: 1.2959 - accuracy: 0.5366 - val_loss: 1.4096 - val_accuracy: 0.5062
Epoch 4/5
50000 images found for stage train
781/781 [==============================] - ETA: 0s - loss: 1.0716 - accuracy: 0.621810000 images found for stage test
2/2 [==============================] - 0s 3ms/step
781/781 [==============================] - 42s 50ms/step - loss: 1.0716 - accuracy: 0.6218 - val_loss: 1.3959 - val_accuracy: 0.5165
Epoch 5/5
50000 images found for stage train
780/781 [============================>.] - ETA: 0s - loss: 0.8420 - accuracy: 0.705110000 images found for stage test
2/2 [==============================] - 0s 4ms/step
781/781 [==============================] - 42s 50ms/step - loss: 0.8419 - accuracy: 0.7052 - val_loss: 1.8618 - val_accuracy: 0.4493
training took: 220.64626812934875s

Embedding projector

log_dir = os.path.join(log_prefix, "aperturedb_embeddings")
os.makedirs(log_dir, exist_ok=True)

sample_count = 256
# Same inputs and outputs as `model`: the projected vectors are the 10-way
# softmax outputs of the classifier layer, i.e. the "classified outcomes".
# Use model.layers[-2].output instead to project the pooled backbone features.
embeddings = tf.keras.models.Model(inputs=model.input, outputs=model.layers[-1].output)

# Populate the embedding space with the vectors and metadata.
images_pil = []
images_embeddings = []
labels = []

for batch_index in range(sample_count // ds_source.batch_size):
    # Fetch each batch once; every ds_source[i] access rebuilds it via getitem.
    batch_images, batch_one_hot = ds_source[batch_index]
    batch_array = np.asanyarray(batch_images)
    # One predict call per batch instead of one per image; each row of the
    # result is one sample's vector.
    images_embeddings.extend(embeddings.predict(batch_array, verbose=0))
    images_pil.extend(tf.keras.preprocessing.image.array_to_img(img) for img in batch_array)
    class_ids = np.argmax(np.asanyarray(batch_one_hot), axis=1)
    labels.extend(names[class_id] for class_id in class_ids)

len(labels)

1/1 [==============================] - 1s 789ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 13ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 11ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 9ms/step
1/1 [==============================] - 0s 10ms/step
1/1 [==============================] - 0s 12ms/step
1/1 [==============================] - 0s 9ms/step





256

A simpler way to write the same information, using PyTorch's TensorBoard utilities.

# https://github.com/tensorflow/tensorboard/issues/2471#issuecomment-636431260

from torch.utils.tensorboard import SummaryWriter
import torchvision

# The context manager closes (and flushes) the writer even if one of the add_*
# calls raises.
with SummaryWriter(log_dir) as writer:
    # Projector data: one vector per sample, labelled with its class name.
    writer.add_embedding(np.array(images_embeddings), metadata=labels)
    # All sampled thumbnails tiled into a single grid image, logged at step 0.
    image_tensors = [
        torchvision.transforms.functional.pil_to_tensor(img) for img in images_pil
    ]
    writer.add_image("images", torchvision.utils.make_grid(image_tensors), 0)

Some screenshots from tensorboard

epoch accuracy

image

epoch loss

image

epoch confusion matrix

image

Embedding projection after training to ~60% accuracy in 5 epochs

image