04

ChatGPT: Personal Assistant Chatbot

Introduction

ChatGPT [1] is a chatbot that was developed by OpenAI and launched in 2022. It generates human-like text based on the input it receives. The chatbot can assist with various tasks, including answering questions, providing explanations, and producing creative content.

ChatGPT has quickly become one of the most adopted applications in history. It gained over 100 million users in less than three months after its launch [2]. This rapid growth highlights the capabilities of generative AI and its potential to help with day-to-day tasks and enhance productivity. In this chapter, we examine the key components of building a chatbot similar to ChatGPT.

Figure 1: Example of a conversation with ChatGPT

Clarifying Requirements

Here is a typical interaction between a candidate and an interviewer:

Candidate: What languages should the chatbot support?
Interviewer: Initially, let’s focus on English.

Candidate: We need to ensure the chatbot generates unbiased and safe outputs by having strict content moderation and algorithms in place. Is that a fair assumption?
Interviewer: Certainly.

Candidate: Can you specify the range of tasks the chatbot is expected to perform?
Interviewer: The chatbot should be able to handle tasks like providing information and answering questions.

Candidate: Does the chatbot take in or output non-text modalities, such as image, audio, or video?
Interviewer: Let’s focus on a text-only chatbot for now. The input and the output are both text.

Candidate: Should the chatbot handle follow-up questions? How long should the chatbot maintain the conversation context?
Interviewer: Good question. The chatbot should be able to handle follow-up questions within the same conversation session. Let’s say it is expected to have a context window of at least 4096 tokens.

Candidate: Is the chatbot expected to be able to browse websites, call external APIs, or search online?
Interviewer: Let’s not focus on that in this round.

Candidate: Should the chatbot personalize its interactions with users?
Interviewer: Let’s not focus on personalization.

Candidate: Do we have instruction-based training data?
Interviewer: Yes, we have a dataset with 80,000 examples of instructions and answers.

Frame the Problem as an ML Task

Specifying the system’s input and output

The input to the chatbot is a text prompt provided by the user. The prompt can be a question, a command, or any other form of textual query. The output is a relevant and contextually appropriate response generated by the chatbot.

Figure 2: Input and output of a chatbot

Choosing a suitable ML approach

Developing a chatbot is a language generation task in which a language model processes an input prompt and generates a response. Such a model typically needs billions of parameters to learn effectively; hence, these models are often called large language models (LLMs).

As we saw in previous chapters, the decoder-only Transformer is the standard architectural choice for language models. Most modern LLMs, such as OpenAI’s GPT [3], Google’s Gemini [4], and Meta's Llama [5], are based on the decoder-only Transformer. In line with these models, we use a decoder-only Transformer as the architecture for our chatbot.

Data Preparation

The effectiveness of an LLM depends on the quality of its training data, which mainly comes from web sources. This data, often automatically crawled from websites, forums, and blogs, requires special considerations and careful preparation. The most common steps include:

  • Content extraction and parsing: Web-crawled data often contains extraneous elements, for example, HTML tags, advertisements, and navigation links. This step involves parsing the raw HTML content using libraries such as Beautiful Soup [6] or lxml [7] and extracting the main text body while discarding irrelevant sections. Techniques such as DOM analysis [8] and boilerplate detection [9] are applied to isolate and retain the core content relevant to language modeling.
  • URL and domain filtering: Not all web domains provide high-quality or relevant content. URL filtering uses predefined rules or machine learning (ML) classifiers to exclude unwanted sources, for example, low-quality blogs, content farms, or spam sites. Domain allowlisting or blocklisting techniques are also applied to curate data from trusted and relevant sources, ensuring the dataset’s quality and reliability.
  • Language identification: Crawled data often includes multilingual content, which needs to be filtered to match the target language(s) for training. Language detection tools such as fastText [10] or langid.py [11] are employed to classify and filter documents.
  • Content quality filtering: Not all web content is equally valuable for training purposes. Quality assessment techniques, including readability scoring, spam detection algorithms, and heuristic checks (e.g., content length, sentence structure), are used to evaluate and filter low-quality text. ML models can also be employed to predict the quality of web content based on features extracted from the text. This step is crucial to ensure that only high-quality data is used for training.
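The filtering steps above are typically combined into a pipeline of cheap heuristic checks that run before any ML-based classifier. A minimal sketch, where the function name, thresholds, and rules are illustrative assumptions rather than production values:

```python
def passes_heuristic_filters(text: str,
                             min_words: int = 50,
                             max_symbol_ratio: float = 0.1,
                             max_repeated_line_ratio: float = 0.3) -> bool:
    """Cheap quality checks applied before expensive ML-based filtering."""
    words = text.split()
    if len(words) < min_words:  # too short to be useful training text
        return False
    # Reject documents dominated by non-alphanumeric noise (markup debris, ASCII art).
    symbols = sum(1 for c in text if not c.isalnum() and not c.isspace())
    if symbols / max(len(text), 1) > max_symbol_ratio:
        return False
    # Reject documents with heavy line duplication (boilerplate, navigation menus).
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if lines:
        repeated = len(lines) - len(set(lines))
        if repeated / len(lines) > max_repeated_line_ratio:
            return False
    return True
```

In practice, documents that pass such heuristics are then scored by readability or spam classifiers, as described above.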

Along with these techniques, the following are commonly applied to both web-crawled data and other data sources, such as books, articles, and social media posts:

  • Remove inappropriate content: We use ML models to remove offensive, harmful, or NSFW content from training data. This ensures our model learns only appropriate and safe content.
  • Anonymize sensitive information: We anonymize any personally identifiable information (PII) in the dataset. This step is crucial to comply with privacy laws and ethical guidelines.
  • Remove low-quality data: We use ML models to remove low-quality text data. The models assess coherence, relevance, grammar, and readability. This step ensures the training data is high in quality and useful for training.
  • Remove duplicated data (Deduplication): We remove similar texts that might be present in different sources. For example, if the same news article is collected from multiple websites, we identify and retain only one copy. This step reduces redundancy in the training dataset and ensures the model is not overexposed to certain data.
  • Remove irrelevant data: We use heuristics and rule-based methods to remove irrelevant data. For example, we remove texts with non-standard characters or in languages that the chatbot is not expected to support.
  • Tokenize text: We use a subword tokenization algorithm such as Byte-Pair Encoding (BPE) to tokenize the text data. To review BPE, refer to Chapter 3.
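To make the tokenization step concrete, here is a minimal sketch of BPE merge learning on a toy corpus. It is pure Python for illustration only; production tokenizers use optimized implementations, and the function names here are our own:

```python
from collections import Counter

def get_pair_counts(words):
    """Count adjacent symbol pairs, weighted by word frequency."""
    counts = Counter()
    for word, freq in words.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            counts[(a, b)] += freq
    return counts

def merge_pair(pair, words):
    """Merge every occurrence of the pair into a single symbol."""
    target = " ".join(pair)
    replacement = "".join(pair)
    return {w.replace(target, replacement): f for w, f in words.items()}

def train_bpe(corpus, num_merges):
    """Learn BPE merge rules from a whitespace-tokenized toy corpus."""
    # Each word starts as a sequence of characters plus an end-of-word marker.
    words = Counter(" ".join(w) + " </w>" for w in corpus.split())
    merges = []
    for _ in range(num_merges):
        counts = get_pair_counts(words)
        if not counts:
            break
        best = max(counts, key=counts.get)  # most frequent adjacent pair
        words = merge_pair(best, words)
        merges.append(best)
    return merges
```

On the corpus "low low low lower lowest", the first learned merges combine the characters of the frequent stem "low", showing how BPE builds subwords from common character sequences.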

Model Development

Architecture

The LLM’s architecture is based on a decoder-only Transformer. While the text embedding, Transformer blocks, and the prediction head are similar to the decoder-only Transformer discussed in Chapter 2, LLMs typically use more advanced positional encoding methods.

Figure 3: Components of a decoder-only Transformer

Let’s examine the LLM’s positional encoding in more detail.

Positional encoding

In a chatbot setting, the input sequence is typically much longer than a single sentence or email. As per the interviewer’s requirements, our goal is to build a system with a context window of at least 4096 tokens. This requires a positional encoding method that allows the model to understand the positions of all the tokens and the relationships between them.

In this section, we begin with a brief review of absolute positional encoding followed by an exploration of relative positional encoding. Finally, we delve into rotary positional embedding (RoPE) [12], a robust positional encoding method used by popular LLMs such as Llama 3 [13].

Absolute positional encoding

Absolute positional encoding refers to traditional methods such as sinusoidal or learnable encodings whereby each position in a sequence is represented by a unique vector.

In this approach, encoded positions are then added to the token embeddings, providing the model with information about where each token appears in the sequence. Formally, the attention keys and queries are computed using the following equations:

$$
\begin{aligned}
q_m &= W_q\left(e_m+p_m\right) \\
k_n &= W_k\left(e_n+p_n\right)
\end{aligned}
$$

Where:

  • $q_m$ is the query vector at position $m$,
  • $k_n$ is the key vector at position $n$,
  • $W_q$ and $W_k$ are learnable weight matrices,
  • $e_m$ and $e_n$ are token embeddings at positions $m$ and $n$,
  • $p_m$ and $p_n$ are positional vectors (either learnable or fixed) at positions $m$ and $n$.

The attention score is calculated as a dot product of the query and key vectors:

$$
q_m \cdot k_n = e_m W_q W_k e_n + e_m W_q W_k p_n + p_m W_q W_k e_n + p_m W_q W_k p_n
$$

Notice that the positional encodings $p_m$ and $p_n$ depend only on absolute positions. Therefore, this approach captures only information about absolute position, limiting the model's ability to capture relative distances between tokens and to generalize to sequences with different lengths or unseen token positions. For example, a model trained on sequences of up to 512 tokens may struggle when applied to sequences of 4096 tokens. The sinusoidal patterns tend to become repetitive over long distances, resulting in a loss of information about token relationships. This shortcoming is addressed by relative positional encoding.
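As a concrete illustration, the fixed sinusoidal variant of absolute positional encoding can be computed as follows. This is a sketch of the original Transformer formulation, with interleaved sine and cosine dimensions; the resulting vector is what gets added to the token embedding:

```python
import math

def sinusoidal_encoding(position: int, d_model: int) -> list[float]:
    """Fixed absolute positional encoding: interleaved sin/cos at geometric frequencies."""
    pe = []
    for i in range(d_model // 2):
        freq = 1.0 / (10000 ** (2 * i / d_model))
        pe.append(math.sin(position * freq))  # even dimension
        pe.append(math.cos(position * freq))  # odd dimension
    return pe
```

Each position maps to a unique vector, but nothing in the encoding directly expresses the distance between two positions, which is exactly the limitation discussed above.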

Relative positional encoding

In relative positional encoding, instead of encoding the absolute positions of tokens, we encode the differences in the positions of two tokens. This way, the model can focus on the relative distances between tokens, which is often more important than their absolute positions. For example, in a sentence, knowing that the word "car" follows the word "chased" is more informative than knowing the positions of the tokens as numbers 5 and 10.

The attention calculation in relative positional encoding can be expressed in different ways. The T5 paper [14] suggests that the second and the third interaction terms in the original expression in absolute positional encoding can be dropped and that the fourth term could be replaced by a learnable bias:

$$
q_m \cdot k_n = e_m W_q W_k e_n + b_{m,n}
$$
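A toy sketch of this formulation: the content-based scores receive a learned bias looked up purely by the offset $n-m$. The content values and bias table below are made-up numbers for illustration:

```python
# Attention scores for a 3-token sequence: content term plus relative bias b[n - m].
content = [[1.0, 0.2, 0.1],
           [0.2, 1.0, 0.3],
           [0.1, 0.3, 1.0]]  # stands in for e_m W_q W_k e_n (assumed values)
bias = {-2: -0.2, -1: 0.1, 0: 0.5, 1: 0.1, 2: -0.2}  # learned bias, indexed by offset
scores = [[content[m][n] + bias[n - m] for n in range(3)] for m in range(3)]
```

Because the bias depends only on the offset, token pairs at the same distance share the same positional contribution regardless of where they sit in the sequence.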

In contrast, the DeBERTa paper [15] drops the last term and replaces the second and third terms, which contain the absolute positional vectors $p_m$ and $p_n$, with terms based on the relative positional vector $R_{n-m}$:

$$
q_m \cdot k_n = e_m W_q W_k e_n + e_m W_q W_k R_{n-m} + R_{n-m} W_q W_k e_n
$$

Relative positional encoding allows the model to understand the relationships between tokens independently of their absolute positions. However, it introduces additional complexity because $q_m \cdot k_n$ in the attention mechanism can no longer be reduced to a simple dot product. This limits our ability to use efficient techniques such as linear attention [16]. RoPE, which we examine next, addresses this limitation by encoding both absolute and relative positional information through a rotation in the embedding space.

Rotary positional encoding (RoPE)

RoPE represents positional information as a rotation applied to the token embeddings. Given an input sequence, RoPE applies a position-dependent rotation matrix to each embedding. This transformation can be expressed as:

$$
f\left(q_m, m\right) = q_m \cdot R\left(\theta_m\right)
$$

where $q_m$ is the token embedding at position $m$, and $R(\theta_m)$ is a rotation matrix parameterized by the positional angle $\theta_m$. This angle is derived from the position index $m$ and is constructed so that the rotation, built from trigonometric functions, rotates the embedding in the complex plane and captures both absolute and relative positional information.

Figure 4: RoPE in 2D

Figure 4 shows how rotary positional encoding works by rotating word embeddings in a two-dimensional space. The words "cat" and "dog" are represented as vectors, and the angle between them, denoted by $\theta$, encodes their positional relationship. On the left, the sentence is “The cat chased the dog.” The position of "cat" is shown in red, and the position of "dog" is shown in blue. The angle between these vectors captures the relative positioning of the two words in the sentence.

On the right, another sentence “Once upon a time, the cat chased the dog” is shown. Notice that the relative angle between the "cat" and "dog" vectors is still the same, but their absolute positions are different. This demonstrates how RoPE captures both the absolute and relative positions of words, allowing the model to understand the order and distance between words in a sentence.

In a higher-dimensional space, the rotation matrix can be extended to accommodate $d$ dimensions:

$$
R_{\theta, m}^d=\begin{pmatrix}
\cos(m \theta_1) & -\sin(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\
\sin(m \theta_1) & \cos(m \theta_1) & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos(m \theta_2) & -\sin(m \theta_2) & \cdots & 0 & 0 \\
0 & 0 & \sin(m \theta_2) & \cos(m \theta_2) & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos(m \theta_{d/2}) & -\sin(m \theta_{d/2}) \\
0 & 0 & 0 & 0 & \cdots & \sin(m \theta_{d/2}) & \cos(m \theta_{d/2})
\end{pmatrix}
$$
Figure 5: Rotation matrix $R_{\theta, m}^d$ in $d$ dimensions parameterized by positional angles $\theta_i$
Pros:
  • Translational invariance: RoPE encodes positional information in a way that remains consistent even when the positions of tokens shift. This helps the model handle changes in position better than other methods.
  • Relative position representation: RoPE’s rotations encode positional information geometrically within the embedding space. This enables the model to inherently understand the relative distances between tokens, unlike traditional sinusoidal encodings, which encode position additively without leveraging this geometric insight.
  • Generalization to unseen positions: Because RoPE encodes position through rotations, the resulting embeddings maintain consistent relationships, regardless of absolute position. This allows for better generalization across varying sequence lengths.
Cons:
  • Mathematical complexity: RoPE introduces additional mathematical operations involving rotations in the embedding space. While these are not overly complex, they are more intricate than traditional positional encoding methods such as sinusoidal or learned positional embeddings.
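The rotation and its key property can be sketched in a few lines of code. This is a simplified illustration of RoPE's pairwise 2D rotations with the geometric angle schedule; the function names are our own:

```python
import math

def rope_rotate(vec, position, base=10000.0):
    """Rotate each consecutive pair of dimensions by the angle position * theta_i,
    where theta_i follows a geometric schedule over the pairs."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]  # 2D rotation of the pair (x, y)
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# The score between a rotated query and key depends only on the offset n - m:
q, k = [1.0, 0.0, 0.5, 0.5], [0.6, 0.8, 0.1, 0.2]
s1 = dot(rope_rotate(q, 3), rope_rotate(k, 7))    # positions 3 and 7 (offset 4)
s2 = dot(rope_rotate(q, 10), rope_rotate(k, 14))  # positions 10 and 14 (offset 4)
# s1 and s2 agree up to floating-point error
```

The final two scores match because rotating the query by $m\theta$ and the key by $n\theta$ leaves only a rotation by $(n-m)\theta$ between them, which is exactly the translational invariance listed in the pros above.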

Training

In earlier chapters, we explored a two-stage strategy for training language models. However, those stages are not sufficient when training advanced chatbots. Most chatbots, including ChatGPT, use a three-stage training strategy:

  • Pretraining
  • Supervised finetuning (SFT)
  • Reinforcement learning from human feedback (RLHF)

Let’s discuss each stage in more detail to understand its purpose.

Figure 6: Three stages of training an LLM

1. Pretraining

Pretraining is the initial stage of the training process. In this stage, a model is trained with an enormous volume of text data from the Internet. The purpose of pretraining is to create a base model with a broad understanding of language and world knowledge.

The pretraining stage requires significant computational resources. It typically requires thousands of GPUs, costs millions of dollars, and takes months of training.

Pretraining data

The pretraining data typically consists of a large corpus of general text data from various sources on the Internet, for example, web pages, books, and social media posts.

Several datasets are commonly used when pretraining LLMs. Each serves a unique purpose, from broadening the model's exposure to diverse language styles to deepening its understanding of specific domains. Commonly used datasets are:

  • Common Crawl: Common Crawl [17] is a publicly available dataset collected from a large number of web pages on the internet. It contains petabytes of data that have been regularly collected since 2008. This data often includes irrelevant information and harmful content; hence, significant data cleaning is needed to make it suitable for training LLMs.
  • C4: C4 [18], created by Google, is a cleaned version of the Common Crawl dataset specifically used for training LLMs.
  • GitHub: The GitHub dataset comprises a vast collection of open-source code repositories. Its purpose is to help the model understand programming languages and code structures.
  • Wikipedia: The Wikipedia dataset includes a wide range of factual information extracted from Wikipedia. This dataset is generally a more reliable source because it is written and edited more carefully.
  • Books: The books dataset is a collection of books across various genres. Books have long textual content and good data quality, both of which contribute to improving the performance of LLMs.
  • ArXiv: The Arxiv dataset contains academic materials and published papers. This dataset helps the model understand the terminology and knowledge within the academic domain.
  • Stack Exchange: Stack Exchange [19] is a website of high-quality questions and answers, primarily in the format of a dialogue between users.

Most popular LLMs are trained on all or some of the listed datasets. For example, Meta’s Llama 1 model uses all of the above datasets, consisting of about 1.4 trillion tokens. Table 1 shows the proportion of each dataset used by Llama 1 during training.
| Dataset | Sampling proportion | Disk size |
| --- | --- | --- |
| Common Crawl | 67.0% | 3.3 TB |
| C4 | 15.0% | 783 GB |
| GitHub | 4.5% | 328 GB |
| Books | 4.5% | 85 GB |
| Wikipedia | 4.5% | 83 GB |
| ArXiv | 2.5% | 92 GB |
| Stack Exchange | 2.0% | 78 GB |

Table 1: Llama 1 pretraining dataset

ML objective and loss function

As we are training a decoder-only Transformer for text generation, we use standard next-token prediction as our ML objective. For the loss function, we employ cross-entropy loss to measure the difference between the predicted token probabilities and the correct tokens.
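Concretely, at each position the model outputs a probability distribution over the vocabulary, and the loss is the negative log-probability assigned to the actual next token. A sketch with a toy vocabulary and made-up model outputs:

```python
import math

def next_token_loss(predicted_probs, target_index):
    """Cross-entropy for one position: negative log-likelihood of the correct token."""
    return -math.log(predicted_probs[target_index])

# Toy 4-word vocabulary; the model predicts the token that follows "the cat sat on the".
vocab = ["mat", "dog", "tree", "car"]
probs = [0.7, 0.1, 0.1, 0.1]  # assumed model output (sums to 1)
loss = next_token_loss(probs, vocab.index("mat"))  # ≈ 0.357
```

The higher the probability the model places on the correct token, the lower the loss; a perfect prediction (probability 1.0) gives a loss of zero.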

The outcome of the pretraining stage

The pretraining stage produces a model that understands language well. This model, usually referred to as the base model, predicts text to follow the given input prompt, generating relevant and meaningful text.

The base model takes the prompt “I want to learn programming” and continues it with “because it is a valuable skill”.
Figure 7: Base model continuing a sentence

While the base model understands language well, it is only capable of continuing the text prompt. To make the model a useful chatbot that answers questions, we need to train the base model further. This leads to the next stage: supervised finetuning.

2. Supervised finetuning (SFT)

SFT, also named instruction finetuning, is the second stage of the training process. In this stage, we finetune the base model on a smaller, high-quality dataset in a (prompt, response) format. The purpose of this stage is to preserve the base model’s language understanding and world knowledge while adapting its behavior to responding to prompts instead of just continuing them.

Training data

The training data for the SFT stage follows the (prompt, response) format. This data is usually called demonstration data because it demonstrates to the model how to respond to prompts.

Figure 8: An example of demonstration data

The main differences between demonstration data and pretraining data, aside from format, are size and quality.

Size: Demonstration data is significantly smaller than pretraining data. It usually ranges from 10,000 to 100,000 (prompt, response) pairs. Table 2 shows the data sizes of popular demonstration datasets.

| Dataset | Size | Notes |
| --- | --- | --- |
| InstructGPT [20] | ~14,500 | OpenAI’s GPT-3 instruction dataset |
| Alpaca [21] | 52,000 | Developed by Stanford researchers |
| Dolly-15K [22] | ~15,000 | Created by Databricks |
| FLAN 2022 [23] | ~104,000 | Developed by Google Research |

Table 2: Common demonstration datasets

Quality: Demonstration data is of higher quality compared to pretraining data. The data is usually created by educated human contractors. In specialized industries such as healthcare or finance, it is essential to hire domain experts to ensure the accuracy and relevance of data. For instance, as shown in Table 3, over one-third of OpenAI’s labelers for GPT’s demonstration dataset held a master’s degree [20]. While this requirement is costly, it's crucial for producing reliable, industry-specific responses.

| Education | Percentage |
| --- | --- |
| Less than a high school degree | 0% |
| High school degree | 10.5% |
| Undergraduate degree | 52.6% |
| Master’s degree | 36.8% |

Table 3: The education levels of OpenAI’s labelers

ML objective and loss function

Although the training data differs from that of the pretraining stage, the model still learns a similar task: generating text one token at a time based on the input prompt. Therefore, the ML objective and loss function remain the same as in the pretraining stage: next-token prediction with cross-entropy loss.

The model takes the prompt “What is the capital...”, and its predicted tokens are compared against the correct tokens (“It is Paris.”) using cross-entropy loss.
Figure 9: Loss calculation over a (prompt, response) example
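The loss over one (prompt, response) example can be sketched as follows: each response token contributes a cross-entropy term, and (as is common in SFT implementations) the prompt tokens are excluded from the loss. The probabilities below are assumed toy values:

```python
import math

# Toy (prompt, response) example.
# Prompt: "What is the capital of France?"  Correct response tokens: "It", "is", "Paris", "."
# Probability the model assigned to each correct response token (assumed values):
correct_probs = [0.6, 0.8, 0.5, 0.9]

# Average cross-entropy over the response tokens only.
loss = -sum(math.log(p) for p in correct_probs) / len(correct_probs)
```

Averaging over the response tokens keeps the loss comparable across examples with responses of different lengths.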
The outcome of the SFT stage

The outcome of this stage is the SFT model, a finetuned version of the base model. Instead of merely continuing the text prompt, the SFT model generates detailed and helpful responses because it has been trained on demonstration data in a (prompt, response) format.

The SFT model takes the prompt “I want to learn programming” and responds with “Start with Python”.
Figure 10: The SFT model answers a prompt instead of continuing it

The SFT model usually generates a grammatically correct and reasonable response. However, it might not always generate the best response; its answers can be unhelpful or even unsafe. Figure 11 displays four plausible responses to a question. Only the second response is both safe and helpful. The first and fourth responses are grammatically and contextually correct but do not offer accurate advice. The third response is helpful but impolite.

A single prompt (“What are some effective w...”) produces four candidate responses: (1) “Go skydiving for an adrenaline...”, (2) “Exercise regularly and maintain...”, (3) “Shame on you! Try meditation!!!”, and (4) “Ignore your problems and hope t...”.
Figure 11: Various responses to a prompt

To ensure the model produces relevant, safe, and helpful responses, we must further finetune the model. This further finetuning is the primary focus of the next stage: RLHF.

3. RLHF

RLHF, also known as the alignment stage, is the final stage in the training process. This stage aligns the model to human preferences, that is, it adapts the model to generate responses preferred by humans.

To understand RLHF, let's briefly revisit the SFT stage. In SFT, the model learns from demonstration data to produce a plausible response to a given prompt. However, demonstration data provides the model with one plausible response for a prompt, which is not necessarily the most helpful or relevant response. Usually, multiple responses can be plausible, and some will be more relevant than others, as shown in Figure 11.

If we have a separate reward model that can score how relevant a model’s response is to a prompt, we can further finetune the SFT model to generate not just any plausible response but one with a high score. This is the key idea behind RLHF. RLHF consists of two sequential steps:

  1. Training a reward model
  2. Optimizing the SFT model

3.1 Training a reward model

The first step in RLHF is to train a reward model that evaluates the relevance of a response to a prompt. This model takes a (prompt, response) pair as input and produces a score predicting the helpfulness of the response. The higher the score, the more helpful the response is expected to be. Figure 12 illustrates the predicted scores for different (prompt, response) pairs.

Image represents a flowchart illustrating a reward model's scoring mechanism for two different responses to the same prompt.  The flowchart is divided into two sections by a dashed line, representing Example 1 and Example 2. Each section begins with a 'Prompt' box containing the text 'What are some effective w...', followed by a 'Response' box.  Example 1's response is 'Exercise regularly and maintain...', while Example 2's response is 'Shame on you! Try meditation!!!'.  Each 'Response' box is connected to a 'Reward Model' box (represented in light purple).  The 'Reward Model' processes the response and outputs a 'Score' to a circular box.  Example 1 receives a score of '5', while Example 2 receives a score of '1', indicating that the reward model assigns higher scores to responses deemed more effective or appropriate based on some underlying criteria not explicitly shown in the diagram.  The arrows clearly show the flow of information from the prompt and response to the reward model and finally to the score.
Figure 12: Reward model input and output

Reward model architecture

Training a model to output a score is a very common task in ML. There are various architectures we can employ for reward modeling: it can be a decoder-only, encoder-only, or encoder-decoder Transformer, as long as it outputs a scalar value.

Based on public studies, there isn't a consistent pattern of reward models being larger or smaller than the language models they are used to train. For example, OpenAI uses a 6B reward model for the 175B language model [20]. Anthropic uses language models and reward models ranging from 10B to 52B parameters [24]. A typical option is to create a copy of the SFT model and add a prediction head to produce the relevance score for the given (prompt, response) pair.
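As a sketch of this design, the snippet below mimics a backbone that encodes a (prompt, response) pair into a hidden vector, plus a linear prediction head that maps that vector to a single scalar score. The backbone is a random stub, and all names and sizes here are illustrative rather than taken from any real model:

```python
import random

random.seed(0)

HIDDEN_SIZE = 8  # illustrative hidden dimension

def backbone(prompt, response):
    # Stand-in for the Transformer backbone (a copy of the SFT model):
    # deterministically maps a (prompt, response) pair to a hidden vector.
    random.seed(hash((prompt, response)) % (2**32))
    return [random.uniform(-1, 1) for _ in range(HIDDEN_SIZE)]

# The "prediction head" is a single linear layer producing one scalar.
head_weights = [random.uniform(-0.1, 0.1) for _ in range(HIDDEN_SIZE)]
head_bias = 0.0

def reward_score(prompt, response):
    hidden = backbone(prompt, response)
    return sum(w * h for w, h in zip(head_weights, hidden)) + head_bias

score = reward_score("What is 2+2?", "Math is hard.")  # a single scalar
```

In a real system the backbone weights are initialized from the SFT model and the head is trained on preference data; only the scalar-output shape is the point here.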

Image represents a system for evaluating the quality of responses generated by a language model.  At the bottom, a 'Prompt' ('What is 2+2?') and a 'Response' ('Math is hard.') are input into a 'Reward Model...'. This model processes the prompt and response, producing multiple intermediate representations (shown as vertically stacked rectangles), which are then aggregated.  These aggregated representations are fed into a 'Prediction He...' module, which generates a prediction.  Finally, a 'score' (0.3 in this example) is output, representing the quality of the response as assessed by the system.  The ellipsis (...) indicates that multiple intermediate representations are generated by the Reward Model, suggesting a process involving multiple steps or features for evaluating the response.  The arrows depict the flow of information between the components.
Figure 13: Reward model architecture

Training data

To collect training data for reward modeling, we follow these steps:

  1. Collect prompts: Manually create a list of prompts.
  2. Generate multiple responses: Use the SFT model to generate multiple responses for each prompt.
  3. Rank responses: Ask contractors to evaluate those responses and rank them based on their relevance. The reason they typically rank them instead of scoring each response is that ranking reduces subjectivity and inconsistency. It is easier and more intuitive for annotators to compare responses directly than to assign numerical scores, which can vary between annotators. This approach simplifies the evaluation process and ensures more reliable data for training.
  4. Create preference pairs: Construct the training dataset by forming pairs in the format (prompt, winning response, losing response). In each pair, the winning response is preferred over the losing response based on the rankings from the previous step.
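The pairing step can be sketched in a few lines: given a human ranking from best to worst, every ordered combination of two responses yields one (prompt, winning response, losing response) example. The function name and prompt below are illustrative:

```python
from itertools import combinations

def build_preference_pairs(prompt, ranked_responses):
    """Turn a human ranking (best to worst) into (prompt, winner, loser) pairs."""
    pairs = []
    for winner, loser in combinations(ranked_responses, 2):
        # combinations preserves order, so `winner` always ranks above `loser`
        pairs.append((prompt, winner, loser))
    return pairs

# A ranking R1 > R2 > R3 yields three training pairs.
pairs = build_preference_pairs("What's 2 + 2?", ["R1", "R2", "R3"])
```

Note that ranking n responses produces n·(n−1)/2 preference pairs, which is one reason rankings are a data-efficient way to collect comparisons.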

Figure 14 shows the process of collecting training data to train the reward model.

Image represents a system for evaluating responses from a Large Language Model (LLM).  The process begins (1) with a prompt list containing questions like 'What is the capital of France?', 'Name a famous physicist?', 'What's 2 + 2?', and 'Give a synonym for 'happy.''.  These prompts are fed (2) into an 'SFT Model' (likely a fine-tuned large language model), which generates three different responses (Response 1, Response 2, Response 3) for each prompt.  A human evaluator (3) then reviews these responses and determines a 'winning' and 'losing' response for each prompt, based on accuracy and quality. This information is compiled (4) into a table showing the winning and losing responses for each prompt. Finally, the evaluator provides rankings (R1 > R2 > R3, for example) indicating the relative quality of the three responses for each prompt, providing feedback on the LLM's performance.  The entire diagram illustrates a human-in-the-loop evaluation process for ranking the quality of LLM responses.
Figure 14: Collecting training data to train a reward model

Once we collect the training data, where each example is in (prompt, winning response, losing response) format, we define the ML objective and the associated loss function to train our reward model.

ML objective and loss function

The reward model aims to predict a higher score for the winning response compared to the losing response. More formally, for a given (prompt, winning response, losing response) example, the ML objective is to maximize $S_{\text{win}} - S_{\text{lose}}$, where:

  • $S_{\text{win}}$ is the predicted score for the (prompt, winning response) pair
  • $S_{\text{lose}}$ is the predicted score for the (prompt, losing response) pair

To achieve this ML objective, we need a loss function that penalizes the model when the difference between the winning and losing scores is too small. A commonly used loss function for this purpose is the margin ranking loss. The loss function is defined as:

$$\mathcal{L}\left(S_{\text{win}}, S_{\text{lose}}\right) = \max\left(0,\; m - \left(S_{\text{win}} - S_{\text{lose}}\right)\right)$$

Where $m$ is a hyperparameter defining the margin. This margin indicates the minimum desired difference between the scores of the winning and losing responses. If the difference between $S_{\text{win}}$ and $S_{\text{lose}}$ is less than $m$, the optimizer will update the model parameters so that either $S_{\text{win}}$ increases or $S_{\text{lose}}$ decreases.
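A minimal sketch of this loss, assuming scalar scores for a single training example:

```python
def margin_ranking_loss(s_win, s_lose, margin=1.0):
    """Margin ranking loss: zero once s_win exceeds s_lose by at least `margin`."""
    return max(0.0, margin - (s_win - s_lose))

margin_ranking_loss(5.0, 1.0)  # difference 4.0 >= margin, so the loss is 0.0
margin_ranking_loss(2.0, 1.7)  # difference 0.3 < margin, so the loss is positive
```

In practice the loss is averaged over a batch of preference pairs and differentiated with respect to the reward model's parameters; the toy function above only shows the per-example computation.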

Image represents a simplified reward model for a generative AI system.  A central, peach-colored rectangle labeled 'Reward Model' receives input from three sources: 'Prompt,' displaying the text 'What is 2+2?'; 'Winning res...', showing the correct answer 'Four.'; and 'Losing resp...', displaying the incorrect answer 'Math is hard.'  Arrows indicate the flow of information into the Reward Model.  The Reward Model then outputs two values, represented by  '$S_{win}' and '$S_{los}', connected by a dashed line labeled 'Margin...', suggesting a comparison or difference between the winning and losing reward signals. Arrows point from the Reward Model to these output values, indicating that the model assigns different reward signals based on the correctness of the response.  The entire diagram illustrates how the system evaluates responses based on a prompt and assigns rewards accordingly.
Figure 15: Reward modeling loss calculation for a single example from training data

The outcome of reward modeling

The outcome of this step is a reward model that predicts relevance scores for (prompt, response) pairs. These scores reflect human judgments and are crucial for the second step in RLHF.

Image represents a simplified diagram of a reinforcement learning system, specifically focusing on the reward mechanism.  The diagram shows a rectangular box labeled 'Reward Model' in peach/light-orange with a golden border, representing the core component that evaluates the quality of a generated response.  Below this box, '(Prompt, Response)' indicates that the input to the Reward Model consists of a prompt and the corresponding generated response. A single upward-pointing arrow connects '(Prompt, Response)' to the 'Reward Model,' signifying the flow of data.  Another upward-pointing arrow connects the 'Reward Model' to the word 'Score' at the top, indicating that the Reward Model outputs a numerical score based on its evaluation of the prompt and response pair.  The overall structure illustrates a process where a prompt and response are fed into a Reward Model, which then generates a score reflecting the quality of the response.
Figure 16: The reward model predicts the relevance score for a (prompt, response) pair

3.2 Optimizing the SFT model

In the second step of RLHF, the SFT model is optimized with the help of the reward model. The purpose of this step is to adapt the SFT model to generate responses that are not only plausible but also helpful, based on the scores from the reward model.

A common approach to optimize the SFT model is to employ a reinforcement learning (RL) algorithm such as proximal policy optimization (PPO) [25], where the SFT model is finetuned to maximize the scores predicted by the reward model. This finetuning process performs the following steps iteratively:

  1. Generate model responses: The model generates multiple possible responses for a given prompt.
  2. Compute rewards: The reward model scores these responses.
  3. Update model weights: The RL algorithm updates model weights to maximize the expected reward. This step reinforces responses that receive higher scores from the reward model.
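The loop above can be illustrated with a deliberately tiny sketch. Real systems run PPO over token sequences; here a one-step REINFORCE update over three canned responses stands in for the RL algorithm, and the reward scores are made up for illustration:

```python
import math
import random

random.seed(0)

# Toy setup: the "policy" picks one of three canned responses to a fixed prompt,
# and a stand-in reward model assigns each a fixed score (values are illustrative).
responses = ["Exercise regularly...", "Shame on you!...", "Ignore your problems..."]
rewards = {0: 5.0, 1: 1.0, 2: 0.5}
logits = [0.0, 0.0, 0.0]  # policy parameters, one per response

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

lr, baseline = 0.1, 2.0
for step in range(500):
    probs = softmax(logits)
    action = random.choices(range(3), weights=probs)[0]  # 1. generate a response
    reward = rewards[action]                             # 2. compute its reward
    for i in range(3):                                   # 3. update the policy
        grad = (1.0 if i == action else 0.0) - probs[i]
        logits[i] += lr * (reward - baseline) * grad

final_probs = softmax(logits)
# After training, the policy concentrates on the highest-reward response.
```

PPO adds clipping, a KL penalty against the SFT model, and per-token credit assignment on top of this basic "reinforce what the reward model likes" idea.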

Figure 17 shows this process for a single response. In practice, multiple responses are generated and evaluated simultaneously.

Image represents a simplified reinforcement learning (RL) system diagram.  A light orange rectangle labeled 'RL Model' is the core component, receiving input from an unspecified source indicated by 'What is...'. The RL model processes this input and produces an output, represented by a downward arrow and the text 'Math is hard.', which signifies the model's computation. This output feeds into a light green rectangle labeled 'Reward...', representing the reward signal generated by the model's actions.  A score of '1.8' is shown, likely indicating the performance metric of the RL model.  This reward signal is then fed into a purple rectangle labeled 'PPO,' which stands for Proximal Policy Optimization, an optimization algorithm.  The PPO algorithm uses the reward to optimize the RL model, indicated by an arrow labeled 'Optimize' connecting PPO back to the RL Model.  The system forms a closed loop where the PPO algorithm improves the RL model based on the reward signal, creating a continuous optimization process.
Figure 17: Optimizing the model with the PPO algorithm

Training data

For this step, the training data usually consists of a list of prompts created by contractors, typically between 10,000 and 100,000 prompts.

ML objective and loss function

Well-known LLMs such as ChatGPT and Llama use RL algorithms such as PPO and direct preference optimization (DPO) [26]. However, the details of these algorithms are usually beyond the scope of most ML system design interviews. For more information, refer to [27] and [28].

The outcome of RLHF

The outcome of the RLHF stage is usually the final model that can be deployed as a chatbot. Table 4 lists some of the most popular LLMs.

LLM name | Developer | Release date | Access | Parameters
o1 | OpenAI | September 12, 2024 | Preview only | Unknown
GPT-4o | OpenAI | May 13, 2024 | API | Unknown
Claude 3 | Anthropic | March 14, 2024 | API | Unknown
Gemini 1.5 | DeepMind | February 2, 2024 | API | Unknown
Llama 3 | Meta AI | April 18, 2024 | Open-Source | 8 and 70 billion
Grok-1 | xAI | November 4, 2023 | Open-Source | 314 billion
Mixtral 8x22B | Mistral AI | April 10, 2024 | Open-Source | 141 billion
Gemma | DeepMind | February 21, 2024 | Open-Source | 2 and 7 billion
Phi-3 | Microsoft | April 23, 2024 | Open-Source | 3.8 billion
DBRX | Databricks | March 27, 2024 | Open-Source | 132 billion

Table 4: Popular LLMs

To summarize the training section, we employ a three-stage training strategy, including pretraining, SFT, and RLHF. Pretraining involves training a model on a large corpus of text to gain a broad language understanding. SFT finetunes the model to adapt its output to a (prompt, response) format. RLHF further refines the model's responses to be helpful, safe, and aligned with human preferences.

Image represents a flowchart illustrating the stages involved in training a large language model (LLM), specifically highlighting the progression from a base model to a refined model using reinforcement learning.  The flowchart is divided into four main columns representing distinct training stages: Pretraining, Supervised Finetuning, Reward Modeling, and Reinforcement Learning. Each stage consists of three rows detailing the dataset used, the computational resources employed (number of GPUs), and the algorithm applied.  The Pretraining stage uses internet data, thousands of GPUs, and a language modeling algorithm to create a 'Base Model' (examples: GPT, Llama, PaLM).  The Supervised Finetuning stage takes the Base Model as input, utilizes demonstration data and 1-100 GPUs with a language modeling algorithm to produce an 'SFT Model' (example: Vicuna-13B).  The Reward Modeling stage uses comparisons data, 1-100 GPUs, and a regression algorithm to create a 'Reward Model' based on the SFT Model. Finally, the Reinforcement Learning stage uses prompts, 1-100 GPUs, and a reinforcement learning algorithm, taking both the SFT Model and the Reward Model as input, to generate an 'RL Model' (examples: ChatGPT, Gemini). Arrows indicate the flow of information and the model's progression through each stage.  Each stage's input and output are clearly labeled, along with the computational resources and algorithms used.
Figure 18: Summary of LLM training, inspired by [29]

Sampling

In LLMs, sampling refers to how we select tokens from the model's predicted probability distribution to generate coherent and helpful responses.

Image represents a simplified illustration of a Large Language Model (LLM) generating text.  At the bottom, a sequence of tokens 'What is 2 + 2 ?' is fed as input to the LLM (represented by a peach-colored rectangle). The LLM processes this input and outputs a probability distribution for the next token. This distribution is shown as a vertical column of numbers next to the LLM, with probabilities 0.01, 0.01, 0.93, and 0.00 assigned to different tokens.  A curved arrow connects this probability column to a histogram labeled 'Token probabilit...', which visually represents the same probability distribution. The histogram shows that the token 'four' has a 93% probability, '<EOS>' (end of sequence) has 1%, 'able' has a small probability, and 'zebra' has a very small probability.  An upward arrow connects the highest probability token, 'four', to the output of the LLM, indicating that the LLM selects 'four' as the next word in the sequence based on its highest probability.
Figure 19: Selecting the next token from predicted probabilities

As discussed in Chapter 2, there are various methods for generating text. Some are deterministic, while others are stochastic. In this section, we examine these methods to determine which works better for open-ended text generation.

Image represents a hierarchical tree diagram illustrating different text generation methods.  At the top is a rectangular box labeled 'Text Generation Methods,' which branches down into two main categories: 'Deterministic' and 'Stochastic.'  The 'Deterministic' category further subdivides into two methods: 'Greedy Search' and 'Beam...', represented by rectangular boxes connected by downward-pointing arrows indicating the flow of information or hierarchical relationship.  Similarly, the 'Stochastic' category branches into three methods: 'Multinomial...', 'Top-k...', and 'Top-p...', each also depicted as rectangular boxes connected by downward-pointing arrows.  The overall structure shows a top-down breakdown of text generation approaches, classifying them as either deterministic or stochastic and then further specifying individual techniques within each category.  The ellipses (...) after some method names suggest further details or parameters are omitted for brevity.
Figure 20: Common methods for generating text

Deterministic methods

Deterministic methods such as beam search work well for tasks with short, predictable text lengths. However, they are less effective for open-ended generation, such as dialogue, where output length varies. Let's examine the common issues that arise when using deterministic methods like greedy search or beam search to generate text from LLMs.

Greedy search selects the token with the highest probability at each step of the generation process.
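Greedy decoding can be sketched with a toy next-token table. The probabilities below are illustrative, loosely following Figure 21:

```python
# Toy bigram "model": for each context token, a distribution over next tokens.
next_token_probs = {
    "How":   {"are": 0.56, "do": 0.26, "come": 0.14},
    "are":   {"you": 0.91, "am": 0.03, "dog": 0.01},
    "you":   {"doing": 0.39, "?": 0.38, "work": 0.001},
    "doing": {"<EOS>": 0.9},
}

def greedy_decode(start, max_tokens=10):
    tokens = [start]
    while tokens[-1] in next_token_probs and len(tokens) < max_tokens:
        probs = next_token_probs[tokens[-1]]
        tokens.append(max(probs, key=probs.get))  # always take the argmax token
    return tokens

greedy_decode("How")  # → ['How', 'are', 'you', 'doing', '<EOS>']
```

Because the argmax is taken at every step, the same context always produces the same continuation, which is exactly what makes repetition loops possible.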

Image represents a directed graph illustrating word probabilities in a sentence.  A thick, solid horizontal line labeled 'How' connects to a box labeled '0.56' representing the word 'are'.  From '0.56', a thick solid line labeled 'you' connects to a box labeled '0.91'.  From '0.91', a thick solid line labeled 'doing' connects to a box labeled '0.39'.  Dashed lines represent weaker connections.  A dashed line labeled 'come' connects 'How' to a box labeled '0.14'.  A dashed line labeled 'am' connects '0.56' to a box labeled '0.03'. A dashed line labeled 'dog' connects '0.56' to a box labeled '0.01'. A dashed line labeled 'do' connects 'How' to a box labeled '0.26'. A dashed line labeled 'work' connects '0.91' to a box labeled '0.001'. A dashed line labeled '?' connects '0.91' to a box labeled '0.38'.  Each box contains a numerical value, presumably representing the probability of that word appearing in the context of the sentence, given the preceding words.  The graph visually depicts the probabilistic relationships between words in a sentence structure.
Figure 21: Greedy search

While this method is straightforward and often produces coherent text, it has two major issues:

  • Repetition
  • Suboptimal generation

Repetition: When we use greedy search to select the next tokens, the text quickly starts repeating. This is because the model sometimes gets "stuck" in loops, reusing the same sequence of tokens. This happens when the model assigns high probability to certain words that follow each other.

Image represents a simplified diagram illustrating a Large Language Model (LLM) interaction.  A peach-colored rectangle labeled 'LLM' is the central component.  Below the rectangle, the text 'How is the weather?' acts as an input prompt to the LLM.  An upward-pointing arrow connects this prompt to the LLM, indicating the flow of information into the model.  Above the rectangle, the text 'The weather today is sunny. The weathe...' represents the output generated by the LLM in response to the input prompt. An upward-pointing arrow connects the LLM to this output, showing the information flow from the model. The overall diagram depicts a basic question-answering process where a user's query ('How is the weather?') is processed by the LLM, resulting in a textual response ('The weather today is sunny...').
Figure 22: Repetitive output

Suboptimal generation: Greedy search ignores alternative paths during the text generation. It might miss a high-probability sequence of tokens hidden behind a low-probability token.

Beam search improves upon greedy search by considering multiple sequences simultaneously. At each step, it keeps track of the top k sequences, where k (the beam width) is configurable.
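A minimal beam-search sketch over a toy next-token table (the probabilities are illustrative, loosely following Figure 23):

```python
import math

# Toy next-token distributions; any token without an entry ends the sequence.
next_token_probs = {
    "How":  {"are": 0.31, "do": 0.26, "come": 0.24},
    "are":  {"you": 0.36, "animals": 0.21, "am": 0.03},
    "do":   {"you": 0.63, "people": 0.21, "the": 0.001},
    "come": {"are": 0.24, "plants": 0.21},
}

def beam_search(start, beam_width=3, steps=2):
    # Each beam is (tokens, cumulative log-probability).
    beams = [([start], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, logp in beams:
            probs = next_token_probs.get(tokens[-1], {})
            if not probs:  # dead end: keep the beam as-is
                candidates.append((tokens, logp))
                continue
            for tok, p in probs.items():
                candidates.append((tokens + [tok], logp + math.log(p)))
        # Keep only the top `beam_width` candidates by cumulative score.
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
    return beams

best_tokens, best_logp = beam_search("How")[0]
```

In this toy example the top-scoring sequence begins with "do" even though "are" has the higher first-step probability: beam search recovers a high-probability sequence hidden behind a lower-probability first token, which greedy search would have missed.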

Image represents a probabilistic context-free grammar (PCFG) tree illustrating word probabilities in a sentence.  A central node labeled 'How' branches into three main paths, each representing a different word choice: 'come,' 'are,' and 'do.'  These words are connected with thick lines to subsequent nodes containing probabilities (0.24, 0.31, and 0.26 respectively). Each of these nodes further branches out via thinner, dashed lines to other words, representing possible continuations of the sentence. For example, the 'are' node connects to 'plants,' 'animals,' 'you,' and 'am' with associated probabilities (0.001, 0.21, 0.36, and 0.03 respectively for the 'are' node). Similarly, the 'come' node connects to 'are' and 'plants' with probabilities (0.24 and 0.21 respectively). The 'do' node connects to 'you,' 'the,' and 'people' with probabilities (0.63, 0.001, and 0.21 respectively).  The numbers within the boxes represent the conditional probability of a word given its parent node in the tree.  The overall structure shows the branching possibilities and associated probabilities of different word sequences, suggesting a language model's prediction of word choices based on preceding words.
Figure 23: Beam search with a beam width of 3

Beam search allows for more exploration and produces higher-quality text than greedy search. However, it can struggle with open-ended generation. Two common issues with beam search are:

  • Inefficiency
  • Repetition

Inefficiency: Beam search can be computationally inefficient because it requires evaluating multiple sequences at once, which can slow down the generation process.

Repetition: Beam search can lead to repetitive and generic responses. It sometimes gets stuck in a loop and repeats common phrases.

So far, we've seen that deterministic methods struggle with repetition and do not work well for open-ended text generation. Let's explore stochastic methods, which are more commonly used for text generation in LLMs.

Stochastic methods

Stochastic methods generate text by introducing randomness. This randomness makes them more suitable for open-ended text generation. Three popular stochastic methods are:

  • Multinomial sampling
  • Top-k sampling
  • Top-p (nucleus) sampling

Multinomial sampling

Multinomial sampling selects the next token based on the probability distribution of the model's predictions. Each token has a probability associated with it, and a token is chosen based on these probabilities.
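In code, multinomial sampling is simply a weighted draw from the predicted distribution. The vocabulary and probabilities below are illustrative:

```python
import random

random.seed(42)

# Illustrative next-token distribution after the context "How";
# the remaining probability mass is spread over many other tokens.
tokens = ["are", "is", "do", "should", "hard"]
probs  = [0.37, 0.21, 0.12, 0.04, 0.02]

# Multinomial sampling: draw a token in proportion to its probability.
samples = [random.choices(tokens, weights=probs)[0] for _ in range(1000)]
```

Over many draws, "are" is chosen far more often than "hard", but every token, however unlikely, can still appear, which is where the incoherent generations come from.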

Image represents a bar chart illustrating the probability distribution of choosing different auxiliary verbs ('are,' 'is,' 'do,' 'should,' 'hard') to complete a sentence fragment, likely within a natural language processing context.  The horizontal axis displays the auxiliary verbs, with their corresponding probabilities shown as bar heights.  'are' has the highest probability (37%), followed by 'is' (21%), 'do' (12%), 'should' (4%), 'hard' (2%), and '0.13%' representing progressively lower probabilities, indicated by an ellipsis suggesting further, less likely options.  Two dashed arrows highlight the most probable choices: one points to 'are' with the text 'Choose 'are' with...', indicating its high likelihood of selection; the other points to 'hard' with the text 'Choose 'hard' with 2...', suggesting a less likely but still considered option. The bottom annotation,  `$P(w I \text{'How'})$`, likely represents a conditional probability formula, indicating the probability of choosing a word *w* given the context 'How'.
Figure 24: Multinomial sampling

This approach ensures a wide variety of possible outputs. However, it introduces a significant amount of randomness, especially when the probability distribution is flat. This randomness often results in generations that are not coherent. For example, the generated text shown in Figure 25 is the output of the GPT-2 model using multinomial sampling.

Image represents a simple generative model architecture.  At the bottom is a rectangular box labeled 'Model,' representing a language model.  Above it, the text 'Multinomial samp...' indicates that the model uses multinomial sampling to generate text. A vertical arrow connects the 'Model' box to a large rectangular box at the top containing the text 'I enjoy walking with my cute dog for the rest of t...', which represents the generated text output. The arrow points upwards, showing the flow of information from the model to the generated text.  The ellipsis ('...') suggests that the generated text is truncated and continues beyond what's shown.  The overall diagram illustrates a basic text generation process where the model produces text through multinomial sampling.
Figure 25: GPT-2 output using multinomial sampling

Due to coherence issues, multinomial sampling is rarely used in LLMs for text generation.

Top-k sampling

Top-k sampling [30] is a more advanced method that selects from the k most likely tokens rather than sampling from the entire distribution.

Image represents a bar chart illustrating the probability distribution of different words given the context 'How'.  The horizontal axis displays a series of words: 'are,' 'is,' 'do,' 'should,' 'hard,' and an ellipsis indicating further words. The vertical axis represents probability, ranging from 0.0 to 1.0.  The height of each bar corresponds to the probability of that word following 'How'.  The bars are ordered from highest to lowest probability: 'are' (37%), 'is' (21%), 'do' (12%), 'should' (4%), 'hard' (2%), and '0.13%' representing a much lower probability. A dashed line encloses the three highest-probability words ('are,' 'is,' and 'do'). A curved dashed arrow extends from this enclosed area to the right, pointing to the text 'Sample from top three...', indicating that a selection is made from these top three words.  The bottom of the image shows a mathematical formula:  `$P(w I \text{'How'})$`, which likely represents the conditional probability of word *w* given the context 'How'.
Figure 26: Example of top-k sampling with k=3

Here is a step-by-step process to select the next token in top-k sampling:

  1. The model predicts the probability distribution for the next token, providing a probability for each token in the vocabulary.
  2. The tokens are sorted in descending order based on their predicted probabilities.
  3. The top k tokens with the highest probabilities are considered for sampling.
  4. The probabilities of the top k tokens are normalized to ensure they sum to 1.
  5. A token is sampled from this normalized distribution.
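The five steps above map directly onto a short function. The vocabulary and probabilities are illustrative:

```python
import random

random.seed(0)

def top_k_sample(token_probs, k=3):
    # 1-2. Sort tokens by predicted probability, descending.
    ranked = sorted(token_probs.items(), key=lambda kv: kv[1], reverse=True)
    # 3. Keep the k most likely tokens.
    top = ranked[:k]
    # 4. Renormalize so the kept probabilities sum to 1.
    total = sum(p for _, p in top)
    tokens = [t for t, _ in top]
    weights = [p / total for _, p in top]
    # 5. Sample from the renormalized distribution.
    return random.choices(tokens, weights=weights)[0]

probs = {"are": 0.37, "is": 0.21, "do": 0.12, "should": 0.04, "hard": 0.02}
token = top_k_sample(probs, k=3)  # only "are", "is", or "do" can be chosen
```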
Image represents a simple generative model architecture.  At the bottom is a rectangular box labeled 'Model,' representing a language model or similar generative AI.  Above this box, a vertical arrow points upwards, labeled 'Top-k sampling (k=50),' indicating that the model's output is processed using this sampling technique, where only the top 50 most probable next words are considered.  At the very top is a rectangular box containing the text 'I enjoy walking with my cute dog for the rest of...', representing an input prompt or a partial text sequence fed into the 'Model.' The arrow indicates that the output of the 'Model' after Top-k sampling is the continuation of the input text.  The value 'k=50' specifies that the Top-k sampling algorithm considers only the 50 most likely word candidates at each step of text generation.
Figure 27: GPT-2 output using top-k sampling with k=50

Top-k sampling balances coherence and diversity by picking from the top-k tokens. This reduces the chance of choosing irrelevant tokens while allowing some randomness. GPT-2 initially used top-k sampling, which was crucial to its success and popularity.

A major limitation of top-k sampling is that it always picks from a fixed number of top tokens. This is problematic depending on how the predicted probabilities are spread out. Let’s understand why.

Predicted token probabilities can be sharply or evenly distributed. With a sharp distribution, limiting choices to a fixed number of top tokens can produce nonsensical results because the model might miss the best choice. Conversely, with a flat distribution, this fixed limit restricts the model's creativity by not considering enough word options. For example, as shown in Figure 28, the model is 89% confident that the next token should be “lot,” but top-k sampling still considers "much" and "high" as possible next tokens to sample from.

Image represents a bar chart illustrating the probability distribution of a word (or n-gram) within a corpus, specifically focusing on the top three most frequent words. The horizontal axis displays different words: 'lot,' 'much,' 'high,' 'where,' 'this,' and an ellipsis indicating further words with lower probabilities. The vertical axis represents probability, ranging from 0.0 to 1.0.  The chart shows a tall bar for 'lot' representing 89% probability, followed by progressively shorter bars for 'much' (4%), 'high' (3%), 'where' (2%), 'this' (1%), and '0.13%' for the last visible word. A dashed line encloses the bars representing the top three most frequent words ('lot,' 'much,' 'high'). A curved dashed arrow extends from this enclosed area to the right, pointing to the text 'Sample from top three mos...', indicating that the enclosed area represents a sample from the top three most frequent words.  Below the chart, the formula '$P(w I \text{'Thanks a'})$' suggests the probability of word *w* given the preceding words 'Thanks a'.
Figure 28: Top-k sampling in sharp distribution

This limitation is addressed in top-p sampling, which we examine next.

Top-p (nucleus) sampling

Top-p sampling [31], also known as nucleus sampling, was developed in 2019. This method dynamically adjusts the number of tokens considered based on their combined probabilities. Instead of sampling only from the k most likely tokens, it chooses from the smallest possible set of tokens whose cumulative probability exceeds a threshold p. This provides a more flexible and adaptive approach compared to top-k sampling.

Image represents two bar charts illustrating the concept of top-p sampling in a language model.  Each chart displays the probability distribution of the next word given a preceding text prompt. The left chart shows the probability distribution for the prompt 'Thanks a...', with 'lot' having the highest probability (89%), followed by 'much' (4%), 'high' (3%), 'where' (2%), 'this' (1%), and others with probabilities less than 1%.  A dashed line encloses the bars representing the top-p selection, indicating that the model would likely select from these words based on their cumulative probability.  The right chart shows the probability distribution for the prompt 'How', with 'are' (31%), 'is' (29%), 'do' (24%), 'come' (7%), 'confident' (4%), and others having lower probabilities.  Similarly, a dashed line encloses the top-p selection, suggesting the model would choose from these words.  Both charts have a y-axis ranging from 0.0 to 1.0 representing probability, and an x-axis showing the potential next words and their associated probabilities.  The text '$P(w I \text{'Thanks a'...}$ and '$P(w I \text{'How'})$' below each chart indicates the conditional probability calculation being visualized, where 'P' represents probability, 'w' represents the next word, and 'I' represents the given text prompt.  Curved dashed arrows labeled 'Top-p sampling...' point from the dashed selection boxes to the right, indicating the sampling process.
Figure 29: Top-p sampling adaptively chooses tokens based on the probability distribution

Here is a step-by-step process to select the next token in top-p sampling:

  1. The model predicts the probability distribution for the next token, providing a probability for each token in the vocabulary.
  2. The tokens are sorted in descending order based on their predicted probabilities.
  3. Instead of selecting a fixed number of tokens, top-p sampling chooses the smallest possible set of tokens whose cumulative probability exceeds a threshold p.
  4. The probabilities of the selected tokens are normalized to ensure they sum to 1.
  5. A token is sampled from this normalized distribution.
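
The steps above can be sketched in a few lines of NumPy (a minimal illustration, not a production decoder):

```python
import numpy as np

def top_p_sample(logits, p=0.92, rng=None):
    """Pick the next token via top-p (nucleus) sampling."""
    rng = rng or np.random.default_rng()
    # 1. Turn logits into a probability distribution (softmax).
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    # 2. Sort tokens by probability, descending.
    order = np.argsort(probs)[::-1]
    # 3. Keep the smallest prefix whose cumulative probability exceeds p.
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), p)) + 1
    kept = order[:cutoff]
    # 4. Renormalize the kept probabilities so they sum to 1.
    kept_probs = probs[kept] / probs[kept].sum()
    # 5. Sample a token from the renormalized distribution.
    return int(rng.choice(kept, p=kept_probs))
```

Note that when one token dominates the distribution, the nucleus shrinks to that token alone, which is exactly the adaptive behavior shown in Figure 29.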
Figure 30: GPT-2 output using top-p sampling with p=0.92

Top-p sampling is widely used in advanced LLMs to generate human-like text. This method ensures the text is coherent and contextually relevant by focusing on the most probable tokens while allowing some randomness.

While we covered the key aspects of different sampling methods, there are more details to each method. Two popular techniques often used in advanced sampling methods are:

  • Temperature
  • Repetition penalty
Temperature

Temperature is a parameter in sampling methods that controls the randomness of predictions during sampling. Mathematically, the temperature parameter T scales the logits (raw scores) of the model’s output before applying the softmax function to generate probabilities. The adjusted softmax formula with temperature is given by:

$$p_i = \frac{\exp(x_i / T)}{\sum_{j=1}^{n} \exp(x_j / T)}$$

where:

  • x_i are the logits (raw scores) for each possible output
  • T is the temperature parameter
  • p_i is the probability of output i after applying the softmax function

When T=1, the softmax function operates normally. When T>1, the model generates a more uniform probability distribution, making predictions more random and diverse. Higher temperatures increase randomness, which helps the model generate more creative outputs. If the model starts to stray off-topic or produce meaningless outputs, this indicates that the temperature has been set too high.

Conversely, when T<1, the model's output becomes more deterministic, with the highest logit values having more influence on the final prediction. Lower temperatures reduce randomness, which is more suitable for tasks requiring precise answers, such as summarization or translation. If the model starts to repeat itself, this indicates that the temperature has been set too low. A temperature of 0 consistently leads to the same output, making the sampling deterministic.
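
The effect of T can be verified with a small NumPy sketch of the temperature-scaled softmax (an illustrative snippet, not tied to any particular model):

```python
import numpy as np

def softmax_with_temperature(logits, T):
    """Apply the temperature-scaled softmax from the formula above."""
    scaled = np.asarray(logits, dtype=float) / T
    scaled -= scaled.max()          # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

logits = [2.0, 1.0, 0.1]
sharp = softmax_with_temperature(logits, 0.5)  # T < 1: distribution is more peaked
flat = softmax_with_temperature(logits, 2.0)   # T > 1: distribution is more uniform
```

With these example logits, the top token's probability is noticeably higher at T=0.5 than at T=2.0, matching the behavior described above.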

Four histograms show the same logits converted to probabilities at temperatures 0.0, 0.5, 2, and 5: lower temperatures concentrate probability on the top tokens, while higher temperatures flatten the distribution.
Figure 31: Different temperature values applied to the same logits
What are typical temperature values?

Most model providers set the permissible temperature range between 0 and 2. Figure 32 illustrates OpenAI's API reference for the temperature setting.

Figure 32: OpenAI’s temperature documentation [32]

In modern LLMs, the temperature parameter typically ranges from 0.1 to 1.5. When set beyond 1.5, outputs can become increasingly erratic and less coherent, which is undesirable. The optimal value depends on the desired behavior and is often determined empirically. The following table, created by [33], suggests possible temperature values for several use cases.

Use case | Temperature | Top-p | Description
Code generation | 0.2 | 0.1 | Generates code that adheres to established patterns and conventions. Output is more deterministic and focused. Useful for generating syntactically correct code.
Creative writing | 0.7 | 0.8 | Generates creative and diverse text for storytelling. Output is more exploratory and less constrained by patterns.
Chatbot responses | 0.5 | 0.5 | Generates conversational responses that balance coherence and diversity. Output is more natural and engaging.

Table 5: Empirical temperature and top-p ranges for different tasks

Repetition penalty

Similarly, applying a repetition penalty can explicitly reduce the likelihood of generating repetitive sequences of tokens. This can be achieved by terminating the generation when repetitive n-grams are detected (as controlled by the “no_repeat_ngram_size” parameter in Hugging Face models) or by directly modifying the probabilities of tokens that have already been sampled earlier in the sequence (such as the “frequency_penalty” in the ChatGPT API).
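
As a rough illustration, a frequency-penalty-style adjustment can be applied directly to the logits before sampling. The function below is a simplified sketch; `alpha` is a hypothetical penalty strength, not an exact API parameter:

```python
from collections import Counter
import numpy as np

def frequency_penalty(logits, generated_ids, alpha=0.5):
    """Lower the logits of already-generated tokens in proportion to
    how often they have appeared so far (simplified sketch)."""
    adjusted = np.array(logits, dtype=float)
    for token_id, count in Counter(generated_ids).items():
        adjusted[token_id] -= alpha * count
    return adjusted
```

Tokens that have been sampled repeatedly receive progressively lower logits, making them less likely to be chosen again.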

If you are interested in learning more about sampling methods in LLMs, refer to [30].

Evaluation

Offline evaluation metrics

Evaluating LLMs like ChatGPT requires more than traditional metrics such as perplexity. These models behave in complex ways and perform differently across various tasks. Therefore, we need to assess their skills in different tasks to ensure the model is both effective and safe.

In this section, we evaluate our LLM from the following perspectives:

  • Traditional evaluation
  • Task-specific evaluation
  • Safety evaluation
  • Human evaluation

Traditional evaluation

Traditional evaluation provides an initial understanding of the LLM’s performance using typical offline metrics. A common metric is perplexity, which measures how accurately the model predicts the exact sequence of tokens in the training data. A low perplexity value indicates that the model assigns higher probabilities, on average, to the tokens in the training data.

While these metrics are important for initial evaluation, they do not provide insights into an LLM's capabilities or limitations. For example, a low perplexity indicates the model is good at predicting the next tokens but does not measure its ability to understand code or solve math problems.
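
Concretely, perplexity is the exponential of the average negative log-probability the model assigns to each observed token. A minimal sketch:

```python
import math

def perplexity(token_probs):
    """Perplexity = exp of the mean negative log-probability
    assigned to each observed token."""
    avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_nll)

# A model that assigns probability 0.25 to every token has perplexity 4:
# on average, it is as uncertain as a uniform choice among 4 tokens.
```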

Task-specific evaluation

To effectively evaluate an LLM, we need to assess its performance across diverse tasks such as mathematics, code generation, and common-sense reasoning. This comprehensive approach helps identify the model’s strengths and weaknesses. Commonly used tasks for evaluating an LLM’s capabilities are:

  • Common-sense reasoning
  • World knowledge
  • Reading comprehension
  • Mathematical reasoning
  • Code generation
  • Composite benchmarks
Common-sense reasoning

Common-sense reasoning evaluates a model's ability to make inferences based on everyday situations and general knowledge. It tests the model's understanding of basic human experiences, logical connections, and assumptions that people naturally make. Examples include interpreting idioms, understanding cause and effect in typical scenarios, and predicting likely outcomes in social situations.

Figure 33: Example of common-sense reasoning

Typical benchmarks for common-sense reasoning are PIQA (Physical Interaction QA) [34], SIQA [35], HellaSwag [36], WinoGrande [37], OpenBookQA [38], and CommonsenseQA [39], each of which focuses on different aspects. For example, the CommonsenseQA benchmark is a multiple-choice question dataset for which the questions require common-sense knowledge to answer. PIQA focuses on reasoning about physical interactions in everyday situations, while HellaSwag focuses on everyday events.

World knowledge

World knowledge refers to the model's factual knowledge about the world, including historical facts, scientific information, geography, and current events. An example would be answering questions about significant historical events or scientific principles.

Figure 34: World knowledge example

Common benchmarks for this task include:

  • TriviaQA [40]: Questions are gathered from trivia and quiz-league websites.
  • Natural Questions (NQ) [41]: A dataset from Google that includes questions and answers found in natural web queries.
  • SQuAD (Stanford Question Answering Dataset) [42]: Contains questions based on Wikipedia articles.
Reading comprehension

Reading comprehension tasks evaluate a model's ability to understand and interpret text passages and answer questions based on them. This is critical for assessing a model’s ability to extract information from a given text and reason about it.

Figure 35: Reading comprehension example from SQuAD

Typical benchmarks for reading comprehension are SQuAD [42], QuAC [43], and BoolQ [44].

Mathematical reasoning

Mathematical reasoning tasks evaluate a model's ability to solve mathematical problems.

Figure 36: Mathematical reasoning example from GSM8K [45]

Two common benchmarks for mathematical reasoning tasks are:

  • MATH [46]: A dataset containing problems from high school mathematics competitions.
  • GSM8K (Grade School Math 8K) [45]: A dataset with grade school math problems to test the model's problem-solving skills.
Code generation

Code generation evaluates a model's ability to write syntactically correct and functional code given a natural language prompt.

Figure 37: Code generation example from HumanEval

Common benchmarks for code generation are:

  • HumanEval [47]: Python coding tasks.
  • MBPP (Mostly Basic Programming Problems) [48]: A crowd-sourced dataset of Python programming tasks designed to be solvable by entry-level programmers.
Composite benchmarks

In addition to specific benchmarks described above, composite benchmarks combine multiple tasks for a broader assessment. Popular composite benchmarks are:

  • MMLU (Massive Multitask Language Understanding) [49]: Consists of multiple-choice questions from a wide range of subjects including humanities, STEM, social sciences, and more, at various difficulty levels.
  • MMMU (Massive Multi-discipline Multimodal Understanding) [50]: MMMU includes college-level, multiple-choice and open-ended questions covering many subjects with varying levels of difficulty. Unlike the text-only MMLU, MMMU pairs questions with images and diagrams, assessing not only subject knowledge but also multimodal reasoning.
  • AGIEval [51]: A comprehensive benchmark designed to test artificial general intelligence across multiple domains and tasks.
  • Meta Llama 3 human evaluation [13]: A high-quality human evaluation set containing 1,800 prompts that cover 12 key use cases: asking for advice, brainstorming, classification, closed question answering, coding, creative writing, extraction, inhabiting a character/persona, open question answering, reasoning, rewriting, and summarization.

To summarize, we use various tasks and benchmarks to evaluate LLMs’ task-specific performance. This evaluation covers understanding and generating human-like responses across various domains. However, evaluation doesn't stop there. Safety evaluation is crucial for the responsible deployment of these models. Let's look deeper.

Safety evaluation

Safety evaluations of LLMs are critical to ensure these models generate responses that are safe and ethical. These evaluations focus on various tasks that help identify and mitigate risks such as the generation of harmful content. The main aspects of safety evaluation include:

  • Toxicity and harmful content
  • Bias and fairness
  • Truthfulness
  • User privacy and data leakage
  • Adversarial robustness
Toxicity and harmful content

We evaluate a model's ability to avoid generating toxic content. Toxicity includes:

  • Hate speech
  • Abusive language
  • Content that may pose harm to individuals, groups, or society
  • Content useful for planning attacks or violence
  • Instructions for finding illegal content
Figure 38: Model’s expected response to toxic prompts

Commonly used benchmarks to evaluate a model's toxicity are:

  • RealToxicityPrompts [52]: Consists of about 100,000 prompts that the model must complete; then a toxicity score is automatically evaluated using PerspectiveAPI [53].
  • ToxiGen [54]: This benchmark tests a model's ability to avoid generating discriminatory language.
  • HateCheck [55]: A suite of tests specifically for hate speech detection, covering various types of hate speech.

Evaluating models using these benchmarks helps identify potential risks and improve the ability of models to generate safe and respectful content.

Bias and fairness

We assess the model's responses for potential biases. This includes detecting gender, racial, and other biases in generated content.

Typical benchmarks are:

  • CrowS-Pairs [56]: Contains paired sentences differing in only one attribute (e.g., gender) to test for bias. This dataset enables measuring biases in 9 categories: gender, religion, race/color, sexual orientation, age, nationality, disability, physical appearance, and socioeconomic status.
  • BBQ [57]: A dataset of hand-written question sets that target attested social biases against different socially relevant categories.
  • BOLD [58]: A large-scale dataset that consists of 23,679 English text generation prompts for bias benchmarking across five domains.

These benchmarks help us ensure the model treats all demographic groups fairly and equally.

Truthfulness

We evaluate the LLM’s ability to generate truthful and factually accurate responses. This includes distinguishing between factual information and common misconceptions or falsehoods.

Figure 39: Truthfulness example from TruthfulQA

A common benchmark to evaluate truthfulness is TruthfulQA [59]. It measures the truthfulness of a model, i.e., its ability to identify when a claim is true. This benchmark can evaluate the risks of a model generating misinformation or false claims.

User privacy and data leakage

We evaluate the LLM’s tendency to leak sensitive information that it may have been exposed to during training. Since LLMs are trained on various publicly available data sources, they might know about people with a public internet presence. These assessments ensure that LLMs do not inadvertently disclose personal information. A common benchmark for this purpose is PrivacyQA [60].

Adversarial robustness

Adversarial robustness tests an LLM’s ability to handle inputs intentionally designed to confuse or trick the model. This is crucial for ensuring the model's reliability and safety in practice. Typical benchmarks for testing an LLM’s adversarial robustness include AdvGLUE [61], TextFooler [62], and AdvBench [63].

To summarize, we use various benchmarks to evaluate the safety of the LLM, which is crucial to ensure users’ safety. While both task-specific and safety evaluations are essential, human evaluation remains the most reliable method for comprehensive assessment.

Human evaluation

In this approach, human evaluators are asked to rate the different aspects of an LLM such as helpfulness and safety. Human evaluation is critical for assessing nuanced aspects of helpfulness and safety that task-specific and safety benchmarks might miss.

Online evaluation metrics

Online evaluation metrics measure how an LLM performs when deployed in production. Commonly used metrics are:

  • User feedback and ratings
  • User engagement
  • Conversion rate
  • Online leaderboards

User feedback and ratings: Users can rate their satisfaction with the model's responses. This direct feedback from users highlights areas that need improvement.

Figure 40: Collecting users’ feedback

User engagement: Metrics such as “number of queries made” and “session duration” can be insightful signals to measure user engagement. High engagement levels often indicate the LLM is effective in providing helpful information.

Conversion rate: Conversion rate refers to the percentage of users who make a purchase or sign up for a service after interacting with the LLM. Conversion rate is a crucial metric to monitor because higher conversion rates indicate that users find the LLM useful enough to pay for the service.

Online leaderboards: Online leaderboards track the performance of various LLMs in real time. A notable example is LMSYS Chatbot Arena [64], a crowdsourced open platform designed to evaluate LLMs. Models are ranked based on more than 800,000 human pairwise comparisons.

The leaderboard lists each model with its rank, Arena Score, 95% confidence interval, vote count, and organization (e.g., OpenAI, Google, Anthropic).
Figure 41: Chatbot Arena leaderboard

Overall ML System Design

Designing a chatbot system such as ChatGPT requires several components working together smoothly. Unlike traditional models, this system combines multiple services and pipelines to ensure efficiency, safety, and continuous improvement. In this section, we'll explore two key pipelines:

  • Training pipeline
  • Inference pipeline

Training pipeline

The training pipeline involves three critical stages: pretraining, SFT, and RLHF. These stages collectively ensure the model is capable and that it generates helpful and safe responses.

Inference pipeline

The inference pipeline includes several components that ensure the safety, relevance, and quality of the generated responses. This pipeline is responsible for real-time interaction with users. The key components in the inference pipeline are:

  • Safety filtering
  • Prompt enhancer
  • Response generator
  • Response safety evaluator
  • Rejection response generator
  • Session management
The text prompt first passes through safety filtering. Safe prompts proceed to the prompt enhancer and then to the response generator, which applies top-p sampling with the trained model; unsafe prompts are routed to the rejection response generator. The generated response is checked by the response safety evaluator before being returned, with unsafe responses again routed to the rejection response generator. Session management tracks the conversation throughout.
Figure 42: Chatbot overall design

Let’s explore each component in more detail.

Safety filtering

This component analyzes the user prompt to detect harmful, inappropriate, or unsafe queries before it is processed by the model. For example, a prompt asking for instructions on creating a harmful device will be rejected and flagged.
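
In practice, this component is usually a trained safety classifier. The toy blocklist below, with hypothetical phrases, only illustrates where the check sits in the pipeline:

```python
# Production systems use learned classifiers; this blocklist sketch is
# purely illustrative, and the phrases below are hypothetical examples.
BLOCKED_PHRASES = ("how to build a bomb", "make a weapon")

def is_prompt_safe(prompt: str) -> bool:
    """Return False if the prompt matches any blocked phrase."""
    text = prompt.lower()
    return not any(phrase in text for phrase in BLOCKED_PHRASES)
```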

Prompt enhancer

The prompt enhancer component refines and enriches the input prompt to make it more informative and detailed. It expands acronyms, corrects misspellings, and adds context where necessary.

The prompt "Tell me about NYC." passes through the prompt enhancer and comes out expanded: "Tell me about New York City (NYC), …".
Figure 43: Example of prompt enhancement

This component ensures the text prompts are clear, unambiguous, and free of grammatical errors before passing it to the model, which helps the model generate better responses.
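
A tiny sketch of one enhancement step, acronym expansion; the lookup table is a hypothetical stand-in for a real resolution service:

```python
# Hypothetical acronym table; a real system would use a much larger
# resource or a model-based rewriter.
ACRONYMS = {"NYC": "New York City (NYC)"}

def enhance_prompt(prompt: str) -> str:
    """Expand known acronyms in the prompt."""
    for short, full in ACRONYMS.items():
        prompt = prompt.replace(short, full)
    return prompt
```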

Response generator

The response generator interacts with the trained LLM and utilizes top-p sampling to generate a helpful response. This component can optionally use other techniques to improve the quality and safety of the generated response. For example, it might generate multiple possible responses and then choose the one that is more appropriate based on a set of predefined criteria.
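
The best-of-n idea can be sketched generically; `generate` and `score` below are hypothetical caller-supplied interfaces, not a specific API:

```python
def best_of_n(generate, score, prompt, n=3):
    """Generate n candidate responses and return the one that scores
    highest under a caller-supplied criterion.

    generate(prompt) -> str and score(text) -> float are hypothetical
    interfaces standing in for the LLM call and the ranking model.
    """
    candidates = [generate(prompt) for _ in range(n)]
    return max(candidates, key=score)
```

The scoring function could be a reward model, a safety classifier, or any other predefined criterion.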

Response safety evaluator

This component evaluates the generated response to detect harmful or inappropriate content before it is shown to the user. It acts as a final safeguard to ensure responses meet ethical and safety standards.

Rejection response generator

This component generates a proper response when the input prompt is unsafe or the generated response is unsuitable. It provides a clear and polite explanation of why the request cannot be fulfilled.

Session management

To maintain conversation context and handle follow-up questions effectively, specific handling is required. For example, when a user is chatting with a model about their favorite movies, the model needs to remember not just the current question, but also their previous mentions of different genres or films.

This component maintains the continuity and coherence of the conversation by tracking the chat history and managing the flow of dialogue. This is achieved by feeding the chat history, along with the enhanced prompt, into the response generator. This design ensures that each response is contextually relevant by referencing previous interactions and appropriately handling the state of the conversation.
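
One simple way to implement this is to prepend recent turns to the enhanced prompt before calling the response generator; the role/text format below is a hypothetical sketch:

```python
def build_model_input(history, enhanced_prompt, max_turns=10):
    """Concatenate recent chat history with the new prompt so the model
    sees the conversation context (role/text format is illustrative)."""
    recent = history[-max_turns:]  # truncate to fit the context window
    lines = [f"{role}: {text}" for role, text in recent]
    lines.append(f"user: {enhanced_prompt}")
    return "\n".join(lines)
```

Truncating to the most recent turns is the simplest policy; real systems may also summarize older turns to preserve long-range context.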

Other Talking Points

If there is extra time at the end of the interview, here are some additional talking points:

  • Techniques for managing dialogue states and tracking context across multiple turns [65].
  • Employing advanced or more efficient ML objectives such as multi-token prediction [66].
  • Handling very long sequence lengths [67][68].
  • How to develop multimodal LLMs [69][70].
  • Techniques such as RAG to leverage external knowledge bases and databases to enhance LLM output [71]. We explore this in Chapter 6.
  • Efficiency techniques (e.g., distillation) for faster text generation.
  • Techniques for adapting LLMs to specific domains (e.g., customer service, healthcare) without forgetting previous knowledge [72].
  • Addressing security and privacy concerns in LLMs.
  • Different optimization algorithms such as PPO, DPO, and rejection sampling [73].
  • Red-teaming LLMs to reduce harm [74].
  • Super-alignment and its importance in developing LLMs [75].
  • How in-context learning works [76].
  • Grouped query attention and its benefits [77].
  • Employing chain-of-thought prompting techniques [78]. We explore this in Chapter 6.
  • Implementing KV cache [79].
  • Enhancing trust by requiring models to produce clear and verifiable justifications for their outputs [80].

Summary

Image represents a mind map summarizing the chatbot design covered in this chapter. The central node branches into the main topics: data processing (web crawling, language identification, content quality filtering, inappropriate content removal), model architecture (positional encoding, the Transformer, next-token prediction), training (instruction data, training a large model, and sampling methods such as greedy decoding, beam search, top-k, top-p, and temperature), offline evaluation (commonsense reasoning, world knowledge, reading comprehension, and code generation benchmarks), human evaluation (user feedback and leaderboards), and the overall system components for serving responses.

Reference Material

[1] OpenAI’s ChatGPT. https://openai.com/index/chatgpt/.
[2] ChatGPT wiki. https://en.wikipedia.org/wiki/ChatGPT.
[3] OpenAI’s models. https://platform.openai.com/docs/models.
[4] Google’s Gemini. https://gemini.google.com/.
[5] Meta’s Llama. https://llama.meta.com/.
[6] Beautiful Soup. https://beautiful-soup-4.readthedocs.io/en/latest/.
[7] Lxml. https://lxml.de/.
[8] Document Object Model. https://en.wikipedia.org/wiki/Document_Object_Model.
[9] Boilerplate removal tool. https://github.com/miso-belica/jusText.
[10] fastText. https://fasttext.cc/.
[11] langid. https://github.com/saffsd/langid.py.
[12] RoFormer: Enhanced Transformer with Rotary Position Embedding. https://arxiv.org/abs/2104.09864.
[13] Llama 3 human evaluation. https://github.com/meta-llama/llama3/blob/main/eval_details.md.
[14] Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. https://arxiv.org/abs/1910.10683.
[15] DeBERTa: Decoding-enhanced BERT with Disentangled Attention. https://arxiv.org/abs/2006.03654.
[16] Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. https://arxiv.org/abs/2006.16236.
[17] Common Crawl. https://commoncrawl.org/.
[18] C4 dataset. https://www.tensorflow.org/datasets/catalog/c4.
[19] Stack Exchange dataset. https://github.com/EleutherAI/stackexchange-dataset.
[20] Training language models to follow instructions with human feedback. https://arxiv.org/abs/2203.02155.
[21] Alpaca. https://crfm.stanford.edu/2023/03/13/alpaca.html.
[22] Dolly-15K. https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm.
[23] Introducing FLAN: More generalizable Language Models with Instruction Fine-Tuning. https://research.google/blog/introducing-flan-more-generalizable-language-models-with-instruction-fine-tuning/.
[24] Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback. https://arxiv.org/abs/2204.05862.
[25] Proximal Policy Optimization Algorithms. https://arxiv.org/abs/1707.06347.
[26] Direct Preference Optimization: Your Language Model is Secretly a Reward Model. https://arxiv.org/abs/2305.18290.
[27] Illustrating RLHF. https://huggingface.co/blog/rlhf.
[28] RLHF progress and challenges. https://www.youtube.com/watch?v=hhiLw5Q_UFg.
[29] State of GPT. https://www.youtube.com/watch?v=bZQun8Y4L2A.
[30] Different sampling methods. https://huggingface.co/blog/how-to-generate.
[31] The Curious Case of Neural Text Degeneration. https://arxiv.org/abs/1904.09751.
[32] OpenAI’s API reference. https://platform.openai.com/docs/api-reference/chat/create.
[33] Cheat Sheet: Mastering Temperature and Top_p in ChatGPT API. https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api/172683.
[34] PIQA: Reasoning about Physical Commonsense in Natural Language. https://arxiv.org/abs/1911.11641.
[35] SocialIQA: Commonsense Reasoning about Social Interactions. https://arxiv.org/abs/1904.09728.
[36] HellaSwag: Can a Machine Really Finish Your Sentence? https://arxiv.org/abs/1905.07830.
[37] WinoGrande: An Adversarial Winograd Schema Challenge at Scale. https://arxiv.org/abs/1907.10641.
[38] Can a Suit of Armor Conduct Electricity? A New Dataset for Open Book Question Answering. https://arxiv.org/abs/1809.02789.
[39] CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge. https://arxiv.org/abs/1811.00937.
[40] TriviaQA: A Large Scale Dataset for Reading Comprehension and Question Answering. https://nlp.cs.washington.edu/triviaqa/.
[41] The Natural Questions Dataset. https://ai.google.com/research/NaturalQuestions.
[42] SQuAD: 100,000+ Questions for Machine Comprehension of Text. https://arxiv.org/abs/1606.05250.
[43] QuAC dataset. https://quac.ai/.
[44] BoolQ: Exploring the Surprising Difficulty of Natural Yes/No Questions. https://arxiv.org/abs/1905.10044.
[45] GSM8K dataset. https://github.com/openai/grade-school-math.
[46] MATH dataset. https://github.com/hendrycks/math/.
[47] HumanEval dataset. https://github.com/openai/human-eval.
[48] MBPP dataset. https://github.com/google-research/google-research/tree/master/mbpp.
[49] Measuring Massive Multitask Language Understanding. https://arxiv.org/abs/2009.03300.
[50] Measuring Massive Multilingual Multitask Language Understanding. https://huggingface.co/datasets/openai/MMMLU.
[51] AGIEval: A Human-Centric Benchmark for Evaluating Foundation Models. https://arxiv.org/abs/2304.06364.
[52] RealToxicityPrompts: Evaluating Neural Toxic Degeneration in Language Models. https://arxiv.org/abs/2009.11462.
[53] Perspective API. https://perspectiveapi.com/.
[54] ToxiGen: A Large-Scale Machine-Generated Dataset for Adversarial and Implicit Hate Speech Detection. https://arxiv.org/abs/2203.09509.
[55] HateCheck: Functional Tests for Hate Speech Detection Models. https://arxiv.org/abs/2012.15606.
[56] CrowS-Pairs: A Challenge Dataset for Measuring Social Biases in Masked Language Models. https://arxiv.org/abs/2010.00133.
[57] BBQ: A Hand-Built Bias Benchmark for Question Answering. https://arxiv.org/abs/2110.08193.
[58] BOLD: Dataset and Metrics for Measuring Biases in Open-Ended Language Generation. https://arxiv.org/abs/2101.11718.
[59] TruthfulQA: Measuring How Models Mimic Human Falsehoods. https://arxiv.org/abs/2109.07958.
[60] Question Answering for Privacy Policies: Combining Computational and Legal Perspectives. https://arxiv.org/abs/1911.00841.
[61] AdvGLUE Benchmark. https://adversarialglue.github.io/.
[62] Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment. https://arxiv.org/abs/1907.11932.
[63] AdvBench. https://github.com/llm-attacks/llm-attacks.
[64] Chatbot Arena leaderboard. https://lmarena.ai/leaderboard.
[65] A Survey on Recent Advances in LLM-Based Multi-turn Dialogue Systems. https://arxiv.org/abs/2402.18013.
[66] Better & Faster Large Language Models via Multi-token Prediction. https://arxiv.org/abs/2404.19737.
[67] Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. https://arxiv.org/abs/2403.05530.
[68] HyperAttention: Long-context Attention in Near-Linear Time. https://arxiv.org/abs/2310.05869.
[69] MM-LLMs: Recent Advances in MultiModal Large Language Models. https://arxiv.org/abs/2401.13601.
[70] Multimodality and Large Multimodal Models. https://huyenchip.com/2023/10/10/multimodal.html.
[71] What is Retrieval-Augmented Generation? https://cloud.google.com/use-cases/retrieval-augmented-generation.
[72] How to Customize an LLM: A Deep Dive to Tailoring an LLM for Your Business. https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/how-to-customize-an-llm-a-deep-dive-to-tailoring-an-llm-for-your/ba-p/4110204.
[73] Llama 2: Open Foundation and Fine-Tuned Chat Models. https://arxiv.org/abs/2307.09288.
[74] Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned. https://arxiv.org/abs/2209.07858.
[75] Introducing superalignment. https://openai.com/index/introducing-superalignment/.
[76] Language Models are Few-Shot Learners. https://arxiv.org/abs/2005.14165.
[77] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. https://arxiv.org/abs/2305.13245.
[78] Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. https://arxiv.org/abs/2201.11903.
[79] Efficiently Scaling Transformer Inference. https://arxiv.org/abs/2211.05102.
[80] Prover-Verifier Games improve legibility of language model outputs. https://openai.com/index/prover-verifier-games-improve-legibility/.
