
Seq2Seq Chatbot Example using TensorFlow

In this article, we’ll build a chatbot using a seq2seq model with an attention mechanism, written in Python with the TensorFlow machine learning library.

Seq2seq, or sequence-to-sequence, models have become popular, particularly in natural language processing. Their applications include machine translation and image captioning.

However, in this tutorial, we’ll focus on a dialogue system, which works quite similarly to machine translation.

Generally, these types of models consist of an encoder, a decoder and a cross-attention mechanism. We’ll build these components from a bidirectional RNN, GRUs and other layers to get a reasonably efficient but lightweight model.

Seq2Seq Chatbot Python code

So, without further ado, the first thing we need to do is import all the necessary libraries into our script.

import pandas as pd
import numpy as np
import einops
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow_text as tf_text
from kaggle.api.kaggle_api_extended import KaggleApi

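Since we’ll be downloading our dataset through the Kaggle API, we also need to authenticate the connection.

# authenticate connection with Kaggle API
# (reads credentials from ~/.kaggle/kaggle.json or the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables)
api = KaggleApi()
api.authenticate()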

After that, we download the dataset using a method from the API library. I created a “datasets” folder inside my project directory to hold the .csv file we’ll be working with.

# download dataset from

Next, we need to set a couple of hyperparameters, which we’ll use throughout the whole process.

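# set hyperparameters
VOCAB_SIZE = 10000
UNITS = 256
BATCH_SIZE = 64  # assumed value; not shown in the original listing
# special tokens used below; '[START]' and '[END]' match the regexes in tokens_to_text
START_TOKEN, END_TOKEN, OOV_TOKEN = '[START]', '[END]', '[UNK]'  # '[UNK]' is the TextVectorization default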

Preprocessing data

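Now we’re ready to load our dataset and preprocess it for our model.

# load and preprocess dataset
dataset = pd.read_csv('datasets/AI.csv')

# any extra split arguments (such as test_size) were lost from the original
# listing, so this call falls back to the scikit-learn defaults
X_train, X_test, y_train, y_test = train_test_split(
    dataset['Question'], dataset['Answer'])

# pair each question with its answer and batch the pairs
train_raw = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE)
val_raw = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(BATCH_SIZE)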

def lower_and_split_punct(text):
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')
    text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')
    text = tf.strings.strip(text)
    text = tf.strings.join([START_TOKEN, text, END_TOKEN], separator=' ')

    return text

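# ragged=True is required because process_text calls .to_tensor() on the output
context_text_processor = tf.keras.layers.TextVectorization(
    standardize=lower_and_split_punct,
    max_tokens=VOCAB_SIZE,
    ragged=True)

# build the question vocabulary from the training questions only
context_text_processor.adapt(train_raw.map(lambda context, target: context))

target_text_processor = tf.keras.layers.TextVectorization(
    standardize=lower_and_split_punct,
    max_tokens=VOCAB_SIZE,
    ragged=True)

# build the answer vocabulary from the training answers only
target_text_processor.adapt(train_raw.map(lambda context, target: target))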

def process_text(context, target):
    context = context_text_processor(context).to_tensor()
    target = target_text_processor(target)
    targ_in = target[:, :-1].to_tensor()
    targ_out = target[:, 1:].to_tensor()
    return (context, targ_in), targ_out

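train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)
val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)

Here, we also created a couple of functions to preprocess the text and assemble the samples. The process_text function arranges the data so that the model learns to predict one token at a time: the decoder input is the target sequence without its last token, and the expected output is the same sequence shifted one position to the left.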
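
To make the shift concrete, here is a minimal plain-Python sketch of what process_text does to a single tokenized answer. The token IDs and the helper name shift_target are made up for illustration; the real pipeline works on batched ragged tensors.

```python
# Illustrative token IDs (hypothetical): 2 = [START], 3 = [END]
def shift_target(target_ids):
    """Split one target sequence into decoder input and expected output."""
    targ_in = target_ids[:-1]   # everything except the final token
    targ_out = target_ids[1:]   # everything except the first token
    return targ_in, targ_out

target = [2, 15, 8, 42, 3]      # [START] w1 w2 w3 [END]
targ_in, targ_out = shift_target(target)

# At each position the decoder reads one input token and should predict the next one
for seen, expected in zip(targ_in, targ_out):
    print(f"input token {seen:>2} -> expected next token {expected}")
```

The two lists always have the same length, so every input position has exactly one token to predict.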

The following class checks that our tensor shapes stay consistent throughout the process.

# shape checking object for debugging
class ShapeChecker():
    def __init__(self):
        # cache of every axis name seen so far and its length
        self.shapes = {}

    def __call__(self, tensor, names, broadcast=False):
        # shapes are only fully known in eager mode
        if not tf.executing_eagerly():
            return

        parsed = einops.parse_shape(tensor, names)

        for name, new_dim in parsed.items():
            old_dim = self.shapes.get(name, None)

            if (broadcast and new_dim == 1):
                continue

            if old_dim is None:
                # first time we see this axis name, so remember its length
                self.shapes[name] = new_dim
                continue

            if new_dim != old_dim:
                raise ValueError(
                    f'Shape mismatch for dimension: {name}\n'
                    f'found: {new_dim}\n'
                    f'expected: {old_dim}\n'
                )
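
The caching logic is easier to see without TensorFlow. The sketch below is a plain-Python analog that takes shape tuples instead of tensors. The class name NamedShapeChecker is hypothetical, and the sketch leaves out the broadcast option.

```python
class NamedShapeChecker:
    """Plain-Python analog of ShapeChecker that works on shape tuples."""

    def __init__(self):
        self.shapes = {}  # axis name -> first length seen

    def __call__(self, shape, names):
        axis_names = names.split()
        if len(axis_names) != len(shape):
            raise ValueError(f'expected {len(axis_names)} axes, got {len(shape)}')
        for name, new_dim in zip(axis_names, shape):
            # setdefault stores new_dim for a new name, else returns the cached length
            old_dim = self.shapes.setdefault(name, new_dim)
            if new_dim != old_dim:
                raise ValueError(
                    f'Shape mismatch for dimension: {name}\n'
                    f'found: {new_dim}\n'
                    f'expected: {old_dim}\n'
                )

checker = NamedShapeChecker()
checker((64, 12), 'batch s')             # records batch=64, s=12
checker((64, 12, 256), 'batch s units')  # consistent, records units=256

try:
    checker((32, 12, 256), 'batch s units')  # batch changed from 64 to 32
    mismatch_caught = False
except ValueError:
    mismatch_caught = True
```

Because each layer creates its own checker inside call, the names only need to agree within a single forward pass.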

Seq2Seq chatbot components

Now, we’re ready to start building our seq2seq model components. In order for our chatbot to learn to answer a question, we need to create an encoder class, which will take a question as an input.

Furthermore, we’ll build it using a bidirectional RNN architecture with gated recurrent units (GRUs).

# define encoder class
class Encoder(tf.keras.layers.Layer):
    def __init__(self, text_processor, units):
        super(Encoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.units = units
        self.embedding = tf.keras.layers.Embedding(self.vocab_size, units, mask_zero=True)
        # read the question in both directions and sum the two output sequences
        self.rnn = tf.keras.layers.Bidirectional(
            merge_mode = 'sum',
            layer = tf.keras.layers.GRU(
                units,
                return_sequences = True,
                recurrent_initializer = 'glorot_uniform'
            )
        )

    def call(self, x):
        shape_checker = ShapeChecker()
        shape_checker(x, 'batch s')

        x = self.embedding(x)
        shape_checker(x, 'batch s units')

        x = self.rnn(x)
        shape_checker(x, 'batch s units')

        return x
    def convert_input(self, texts):
        texts = tf.convert_to_tensor(texts)
        if len(texts.shape) == 0:
            texts = tf.convert_to_tensor(texts)[tf.newaxis]
        context = self.text_processor(texts).to_tensor()
        context = self(context)

        return context

The next step is to define the cross-attention layer, which acts as a bridge between the encoder and the decoder.

# define cross-attention class
class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=1, **kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        shape_checker = ShapeChecker()

        shape_checker(x, 'batch t units')
        shape_checker(context, 'batch s units')

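        # the decoder states query the encoder outputs
        attn_output, attn_scores = self.mha(
            query=x,
            value=context,
            return_attention_scores=True)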

        shape_checker(x, 'batch t units')
        shape_checker(attn_scores, 'batch heads t s')

        attn_scores = tf.reduce_mean(attn_scores, axis=1)
        shape_checker(attn_scores, 'batch t s')
        self.last_attention_weights = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x
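
To see what the attention step computes, here is a plain-Python sketch of unscaled dot-product attention for a single decoder position. It is a simplification, not the Keras layer: it leaves out the learned projections, the score scaling, the residual connection and the layer normalization, and the vectors are made-up toy values.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, context):
    """Return (attention weights, weighted sum of context vectors)."""
    # similarity between the decoder state and every encoder position
    scores = [sum(q * k for q, k in zip(query, vec)) for vec in context]
    weights = softmax(scores)
    # mix the encoder outputs according to those weights
    mixed = [sum(w * vec[i] for w, vec in zip(weights, context))
             for i in range(len(query))]
    return weights, mixed

# toy encoder outputs for a three-token question (s = 3, units = 2)
context = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# toy decoder state for one output position
query = [2.0, 0.0]

weights, mixed = attend(query, context)
print(weights)  # one weight per question token, summing to 1
```

The weights play the same role as attn_scores in the class above: they tell us which question tokens the decoder is looking at for the current output position.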

Okay, now we’re ready to define the decoder class, which is the workhorse of this chatbot system. The decoder learns to generate answers based on the context sequence coming from the encoder.

This is what lets the seq2seq chatbot model learn, from the question-answer pairs, how an answer relates to its question.

# define decoder class
class Decoder(tf.keras.layers.Layer):
    def __init__(self, text_processor, units):
        super(Decoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()

        # token string -> token ID
        self.word_to_id = tf.keras.layers.StringLookup(
            vocabulary = text_processor.get_vocabulary(),
            mask_token = '',
            oov_token = OOV_TOKEN
        )

        # token ID -> token string
        self.id_to_word = tf.keras.layers.StringLookup(
            vocabulary = text_processor.get_vocabulary(),
            mask_token = '',
            oov_token = OOV_TOKEN,
            invert = True
        )

        self.start_token = self.word_to_id(START_TOKEN)
        self.end_token = self.word_to_id(END_TOKEN)

        self.units = units

        self.embedding = tf.keras.layers.Embedding(
            self.vocab_size,
            units,
            mask_zero = True
        )

        self.rnn = tf.keras.layers.GRU(
            units,
            return_sequences = True,
            return_state = True,
            recurrent_initializer = 'glorot_uniform'
        )

        self.attention = CrossAttention(units)
        self.output_layer = tf.keras.layers.Dense(self.vocab_size)

    def call(self, context, x, state=None, return_state=False):
        shape_checker = ShapeChecker()
        shape_checker(x, 'batch t')
        shape_checker(context, 'batch s units')

        x = self.embedding(x)
        shape_checker(x, 'batch t units')

        x, state = self.rnn(x, initial_state=state)
        shape_checker(x, 'batch t units')

        x = self.attention(x, context)
        self.last_attention_weights = self.attention.last_attention_weights
        shape_checker(x, 'batch t units')
        shape_checker(self.last_attention_weights, 'batch t s')

        logits = self.output_layer(x)
        shape_checker(logits, 'batch t target_vocab_size')

        if return_state:
            return logits, state
        else:
            return logits

    def get_initial_state(self, context):
        batch_size = tf.shape(context)[0]
        start_tokens = tf.fill([batch_size, 1], self.start_token)
        done = tf.zeros([batch_size, 1], dtype=tf.bool)
        embedded = self.embedding(start_tokens)
        return start_tokens, done, self.rnn.get_initial_state(embedded)[0]
    def tokens_to_text(self, tokens):
        words = self.id_to_word(tokens)
        result = tf.strings.reduce_join(words, axis=-1, separator=' ')
        # raw strings keep the regex backslashes intact
        result = tf.strings.regex_replace(result, r'^ *\[START\] *', '')
        result = tf.strings.regex_replace(result, r' *\[END\] *$', '')

        return result
    def get_next_token(self, context, next_token, done, state, temperature=0.0):
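        logits, state = self(
            context, next_token,
            state = state,
            return_state = True)

        if temperature == 0.0:
            # greedy decoding: take the most likely token
            next_token = tf.argmax(logits, axis=-1)
        else:
            # sample from the distribution, flattened or sharpened by temperature
            logits = logits[:, -1, :] / temperature
            next_token = tf.random.categorical(logits, num_samples=1)

        # once a sequence emits the end token, it stays done
        done = done | (next_token == self.end_token)
        # finished sequences only produce padding (ID 0)
        next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)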

        return next_token, done, state
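
The effect of temperature in get_next_token can be sketched without TensorFlow. The helper pick_token below is hypothetical and uses Python’s random module in place of tf.random.categorical: at temperature 0 it picks the highest-scoring token, otherwise it samples from the softmax of the scaled logits.

```python
import math
import random

def pick_token(logits, temperature=0.0, rng=random):
    """Choose a token index from raw logits."""
    if temperature == 0.0:
        # greedy decoding: always take the most likely token
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # choices() normalizes the weights for us
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

logits = [0.5, 2.0, 1.0, -1.0]
greedy = pick_token(logits)       # index of the largest logit
rng = random.Random(0)            # seeded so the run is repeatable
sampled = [pick_token(logits, 1.0, rng) for _ in range(1000)]
```

Higher temperatures flatten the distribution and give more varied answers, while values close to zero behave almost like greedy decoding.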

Okay, now it’s time to put all these components together.

# put it all together in Chatbot class 
class Chatbot(tf.keras.Model):
    def __init__(self, units, context_text_processor, target_text_processor):
        super().__init__()
        encoder = Encoder(context_text_processor, units)
        decoder = Decoder(target_text_processor, units)

        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)
        logits = self.decoder(context, x)

        # drop the Keras mask so it doesn't rescale the loss and metrics
        try:
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

    def prompt(self, texts, *, max_length=50, temperature=0.0):
        context = self.encoder.convert_input(texts)
        batch_size = tf.shape(texts)[0]

        tokens = []
        attention_weights = []
        next_token, done, state = self.decoder.get_initial_state(context)

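        for _ in range(max_length):
            next_token, done, state = self.decoder.get_next_token(
                context, next_token, done, state, temperature)

            # collect the generated token and its attention weights
            tokens.append(next_token)
            attention_weights.append(self.decoder.last_attention_weights)

            # stop early once every sequence in the batch has finished
            if tf.executing_eagerly() and tf.reduce_all(done):
                break

        tokens = tf.concat(tokens, axis=-1)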
        self.last_attention_weights = tf.concat(attention_weights, axis=1)

        result = self.decoder.tokens_to_text(tokens)

        return result
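
The generation loop in prompt can be sketched in plain Python to show how the done flags work across a batch. Everything here is hypothetical: the token IDs, the generate helper, and the scripted step function that stands in for the decoder.

```python
END, PAD = 3, 0  # hypothetical token IDs

def generate(step_fn, batch_size, start_token, max_length=50):
    """Run step_fn until every sequence has produced END or max_length is hit."""
    next_tokens = [start_token] * batch_size
    done = [False] * batch_size
    outputs = [[] for _ in range(batch_size)]
    for _ in range(max_length):
        next_tokens = step_fn(next_tokens)
        # a sequence that has emitted END stays done
        done = [d or t == END for d, t in zip(done, next_tokens)]
        # finished sequences only emit padding from now on
        next_tokens = [PAD if d else t for d, t in zip(done, next_tokens)]
        for out, token in zip(outputs, next_tokens):
            out.append(token)
        if all(done):
            break
    return outputs

def make_scripted_step(scripts):
    """Toy stand-in for the decoder: replays fixed token lists."""
    position = 0
    def step(_previous_tokens):
        nonlocal position
        tokens = [s[position] if position < len(s) else END for s in scripts]
        position += 1
        return tokens
    return step

step = make_scripted_step([[5, 6, END], [7, END]])
result = generate(step, batch_size=2, start_token=2)
print(result)
```

The shorter answer is padded with zeros until the longer one finishes, which is why the loop can stop as soon as all flags are set.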

Before we start training our model, we’re also going to define custom masked loss and accuracy functions, which leave padding positions out of the averages.

# custom loss and accuracy functions
def masked_loss(y_true, y_pred):
    # reduction='none' keeps one loss value per token so we can mask them
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask

    return tf.reduce_sum(loss)/tf.reduce_sum(mask)

def masked_acc(y_true, y_pred):
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)

    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)

    return tf.reduce_sum(match)/tf.reduce_sum(mask)
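
To show why the mask matters, here is a plain-Python sketch of a masked cross-entropy for one padded sequence. It is an illustration under simplifying assumptions: it takes probabilities rather than logits, and the helper name and the numbers are made up.

```python
import math

def masked_cross_entropy(y_true, probs):
    """Mean negative log-likelihood over non-padding positions.

    y_true: list of token IDs, where 0 marks padding.
    probs:  list of probability lists, one per position.
    """
    total, count = 0.0, 0
    for token, dist in zip(y_true, probs):
        if token == 0:  # padding contributes nothing
            continue
        total += -math.log(dist[token])
        count += 1
    return total / count

y_true = [2, 1, 0, 0]          # two real tokens, two pads
probs = [
    [0.1, 0.2, 0.7],           # correct token 2 gets 0.7
    [0.1, 0.8, 0.1],           # correct token 1 gets 0.8
    [0.9, 0.05, 0.05],         # padding, ignored
    [0.9, 0.05, 0.05],         # padding, ignored
]
loss = masked_cross_entropy(y_true, probs)
```

Without the mask, the two padding positions would be averaged in and the loss would depend on how much padding a batch happens to contain.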

Now, we have all we need to start the training process.

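# create, compile, and train model
model = Chatbot(UNITS, context_text_processor, target_text_processor)

# the optimizer and loss arguments were lost from the original listing;
# 'adam' and masked_loss are assumptions
model.compile(
    optimizer='adam',
    loss=masked_loss,
    metrics=[masked_acc, masked_loss]
)

vocab_size = 1.0 * target_text_processor.vocabulary_size()
model.evaluate(val_ds, steps=20, return_dict=True)

# the data and epoch arguments were also lost; the values here are assumptions
history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, monitor='masked_loss')]
)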
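
The evaluate call before training works as a sanity check. An untrained model should predict roughly uniformly over the vocabulary, so its loss should be close to the natural log of the vocabulary size and its accuracy close to one over the vocabulary size. Here is a quick sketch of those baselines, using 10000 as a stand-in for the adapted vocabulary size (an assumption; the real value can be smaller):

```python
import math

vocab_size = 10000  # stand-in for target_text_processor.vocabulary_size()

expected_loss = math.log(vocab_size)   # cross-entropy of a uniform guess
expected_acc = 1 / vocab_size          # chance of a uniform guess being right

print(f"expected_loss={expected_loss:.2f}, expected_acc={expected_acc:.4f}")
```

If the numbers from evaluate are far from these, the data pipeline or the loss function is worth a second look before training.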

Save the model

To actually use the model, we need to save it. To do that, we’ll define an export class and save it with a TensorFlow function.

This creates a folder, named after the path passed to the save function, in which our model is stored.

# for saving the model after training
class Export(tf.Module):
    def __init__(self, model):
        self.model = model
    @tf.function(input_signature=[tf.TensorSpec(dtype=tf.string, shape=[None])])
    def prompt(self, inputs):
        return self.model.prompt(inputs)

# save the model
export = Export(model)
tf.saved_model.save(export, 'chatbot', signatures={'serving_default': export.prompt})

Use the seq2seq chatbot model

In order to use the model we just trained, we’ll need to open up a new script. Here we’ll import the necessary libraries and create an infinite loop.

import tensorflow as tf
import tensorflow_text as tf_text  # registers the text ops the saved model depends on
import numpy as np

model = tf.saved_model.load('chatbot')

while True:
    prompt = input('You: ')
    print('Bot:', model.prompt([prompt])[0].numpy().decode())

The following image shows what the finished model can do.

[Image: seq2seq chatbot results]


To conclude, we built a seq2seq model for a chatbot and trained it on AI-related questions and answers.

I hope you gained a better understanding of how these types of machine learning models work and are able to implement one with your own custom data.
