
Text Classification using OpenAI and Pydantic

This tutorial shows how to implement text classification, specifically single-label and multi-label classification, using the OpenAI API and Pydantic models. For full examples, see the hub examples for single classification and multi classification.

Motivation

Text classification is a common problem in many NLP applications, such as spam detection or support ticket categorization. The goal is to provide a systematic way to handle these cases using OpenAI's GPT models in combination with Python data structures.

Single-Label Classification

Defining the Structures

For single-label classification, we define a Pydantic model with a Literal field for the possible labels.

Literals vs Enums

We prefer using Literal types over enum for classification labels. Literals provide better type checking and are more straightforward to use with Pydantic models.
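
As a rough illustration of the difference, the sketch below validates the same input against a Literal field and an Enum field. The names LiteralLabel, EnumLabel, SpamEnum, and is_rejected are made up for this example and are not part of the tutorial's code.

```python
from enum import Enum
from typing import Literal

from pydantic import BaseModel, ValidationError


class LiteralLabel(BaseModel):
    # The allowed labels live directly in the type annotation.
    label: Literal["SPAM", "NOT_SPAM"]


class SpamEnum(str, Enum):
    # The Enum version needs a separate class listing the same values.
    SPAM = "SPAM"
    NOT_SPAM = "NOT_SPAM"


class EnumLabel(BaseModel):
    label: SpamEnum


literal_result = LiteralLabel(label="SPAM")
enum_result = EnumLabel(label="SPAM")

# The Literal field stays a plain string; the Enum field becomes an Enum member.
print(type(literal_result.label).__name__)
print(isinstance(enum_result.label, SpamEnum))


def is_rejected(value: str) -> bool:
    """Return True if Pydantic refuses the value for the Literal field."""
    try:
        LiteralLabel(label=value)
    except ValidationError:
        return True
    return False


print(is_rejected("MAYBE_SPAM"))
```

Both variants reject labels outside the allowed set. The Literal version simply does so with less code and returns an ordinary string.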

Few-Shot Examples

Including few-shot examples in the model's docstring is crucial for improving the model's classification accuracy. These examples guide the AI in understanding the task and expected outputs.

If you want to learn more prompting tips, check out our prompting guide.
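
To see how docstring examples can reach the model, it may help to inspect the JSON schema that Pydantic generates, where the class docstring appears as the description. The sketch assumes Pydantic v2 and that the response model's schema is what gets sent along with the request. SentimentSketch is a hypothetical model used only here.

```python
from typing import Literal

from pydantic import BaseModel


class SentimentSketch(BaseModel):
    """
    Classify the sentiment of a short review.

    Examples:
    - "Loved it, would buy again": POSITIVE
    - "Broke after two days": NEGATIVE
    """

    label: Literal["POSITIVE", "NEGATIVE"]


# Pydantic copies the class docstring into the schema's "description" key.
schema = SentimentSketch.model_json_schema()
description = schema["description"]
print(description)
```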

Chain of Thought

Using chain of thought has been shown to improve the quality of the predictions by roughly 10%.
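
One detail that is easy to miss is field order. The sketch below checks that a reasoning field declared first also comes first in the generated schema. It rests on an assumption: the model tends to fill fields in schema order, so the reasoning gets written before the label. ReasonedLabel is a hypothetical model, and Pydantic v2 is assumed.

```python
from typing import Literal

from pydantic import BaseModel, Field


class ReasonedLabel(BaseModel):
    # Declared first so the reasoning precedes the label in the schema.
    chain_of_thought: str = Field(..., description="Reasoning behind the label.")
    label: Literal["SPAM", "NOT_SPAM"] = Field(..., description="The predicted label.")


# Schema properties keep the order in which the fields were declared.
field_order = list(ReasonedLabel.model_json_schema()["properties"])
print(field_order)
```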

from pydantic import BaseModel, Field
from typing import Literal
from openai import OpenAI
import instructor

# Patch the OpenAI client so that create() accepts
# the response_model keyword
client = instructor.from_openai(OpenAI())


class ClassificationResponse(BaseModel):
    """
    A few-shot example of text classification:

    Examples:
    - "Buy cheap watches now!": SPAM
    - "Meeting at 3 PM in the conference room": NOT_SPAM
    - "You've won a free iPhone! Click here": SPAM
    - "Can you pick up some milk on your way home?": NOT_SPAM
    - "Increase your followers by 10000 overnight!": SPAM
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    label: Literal["SPAM", "NOT_SPAM"] = Field(
        ...,
        description="The predicted class label.",
    )

Classifying Text

The classify function performs the single-label classification.

def classify(data: str) -> ClassificationResponse:
    """Perform single-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=ClassificationResponse,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: <text>{data}</text>",
            },
        ],
    )

Testing and Evaluation

Let's run examples to see if it correctly identifies spam and non-spam messages.

if __name__ == "__main__":
    for text, label in [
        ("Hey Jason! You're awesome", "NOT_SPAM"),
        ("I am a Nigerian prince and I need your help.", "SPAM"),
    ]:
        prediction = classify(text)
        assert prediction.label == label
        print(f"Text: {text}, Predicted Label: {prediction.label}")
        #> Text: Hey Jason! You're awesome, Predicted Label: NOT_SPAM
        #> Text: I am a Nigerian prince and I need your help., Predicted Label: SPAM
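
Hard assertions on model output can fail intermittently, so for larger test sets it may be more useful to measure accuracy and collect the misses. The following is a minimal sketch under that assumption. The evaluate helper and the keyword_stub classifier are hypothetical, and the stub stands in for the real API call so the example runs offline.

```python
from typing import Callable, Iterable, List, Tuple


def evaluate(
    classify_fn: Callable[[str], str],
    examples: Iterable[Tuple[str, str]],
) -> Tuple[float, List[Tuple[str, str, str]]]:
    """Return accuracy and a list of (text, expected, predicted) misses."""
    total = 0
    misses: List[Tuple[str, str, str]] = []
    for text, expected in examples:
        predicted = classify_fn(text)
        total += 1
        if predicted != expected:
            misses.append((text, expected, predicted))
    accuracy = (total - len(misses)) / total if total else 0.0
    return accuracy, misses


def keyword_stub(text: str) -> str:
    """Stand-in classifier: flags any text that mentions a prince."""
    return "SPAM" if "prince" in text.lower() else "NOT_SPAM"


examples = [
    ("Hey Jason! You're awesome", "NOT_SPAM"),
    ("I am a Nigerian prince and I need your help.", "SPAM"),
    ("You've won a free iPhone! Click here", "SPAM"),
]
accuracy, misses = evaluate(keyword_stub, examples)
print(f"Accuracy: {accuracy:.2f}, misses: {len(misses)}")
```

With the real client, one could pass `lambda text: classify(text).label` as classify_fn.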

Multi-Label Classification

Defining the Structures

For multi-label classification, we define a model whose label field is a list of Literal values, and again include few-shot examples in the model's docstring.

from typing import List, Literal
from pydantic import BaseModel, Field


class MultiClassPrediction(BaseModel):
    """
    Class for a multi-label prediction.

    Examples:
    - "My account is locked": ["TECH_ISSUE"]
    - "I can't access my billing info": ["TECH_ISSUE", "BILLING"]
    - "When do you close for holidays?": ["GENERAL_QUERY"]
    - "My payment didn't go through and now I can't log in": ["BILLING", "TECH_ISSUE"]
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )

    class_labels: List[Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]] = Field(
        ...,
        description="The predicted class labels for the support ticket.",
    )
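
A list of Literals rejects unknown labels but still allows duplicates. If that matters, one option is a small validator that removes them, sketched below. This is an optional extension, not part of the tutorial's model. DedupedPrediction, TicketLabel, and drop_duplicates are hypothetical names, and Pydantic v2 is assumed for field_validator.

```python
from typing import List, Literal

from pydantic import BaseModel, ValidationError, field_validator

TicketLabel = Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]


class DedupedPrediction(BaseModel):
    class_labels: List[TicketLabel]

    @field_validator("class_labels")
    @classmethod
    def drop_duplicates(cls, labels: List[str]) -> List[str]:
        # dict.fromkeys keeps the first occurrence of each label, in order.
        return list(dict.fromkeys(labels))


deduped = DedupedPrediction(class_labels=["BILLING", "TECH_ISSUE", "BILLING"])
print(deduped.class_labels)

# A label outside the Literal set fails validation.
try:
    DedupedPrediction(class_labels=["BILLING", "SHIPPING"])
    unknown_label_rejected = False
except ValidationError:
    unknown_label_rejected = True
print(unknown_label_rejected)
```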

Classifying Text

The function multi_classify is responsible for multi-label classification.

def multi_classify(data: str) -> MultiClassPrediction:
    """Perform multi-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=MultiClassPrediction,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: <ticket>{data}</ticket>",
            },
        ],
    )

Testing and Evaluation

Finally, we test the multi-label classification function using a sample support ticket.

# Test multi-label classification
ticket = "My account is locked and I can't access my billing info."
prediction = multi_classify(ticket)
assert "TECH_ISSUE" in prediction.class_labels
assert "BILLING" in prediction.class_labels
print(f"Ticket: {ticket}")
print(f"Predicted Labels: {prediction.class_labels}")
#> Ticket: My account is locked and I can't access my billing info.
#> Predicted Labels: ['TECH_ISSUE', 'BILLING']
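
The order of predicted labels is not guaranteed, and a prediction can be partly right. One way to score multi-label output is to compare label sets and report precision and recall per ticket. The precision_recall helper below is a hypothetical sketch of that idea, and the label lists are hand-written rather than real model output.

```python
from typing import Iterable, Tuple


def precision_recall(
    predicted: Iterable[str], expected: Iterable[str]
) -> Tuple[float, float]:
    """Compare two label collections as sets, ignoring order and duplicates."""
    predicted_set = set(predicted)
    expected_set = set(expected)
    hits = len(predicted_set & expected_set)
    precision = hits / len(predicted_set) if predicted_set else 0.0
    recall = hits / len(expected_set) if expected_set else 0.0
    return precision, recall


# Same labels in a different order count as a perfect match.
exact = precision_recall(["BILLING", "TECH_ISSUE"], ["TECH_ISSUE", "BILLING"])

# One correct label, one wrong label, and one missed label.
partial = precision_recall(["TECH_ISSUE", "GENERAL_QUERY"], ["TECH_ISSUE", "BILLING"])

print(exact, partial)
```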

Using Literals and few-shot examples gives both the single-label and multi-label implementations stronger type safety and clearer guidance for the model, which can lead to more accurate classifications.