Text Classification using OpenAI and Pydantic¶
This tutorial shows how to implement single-label and multi-label text classification using the OpenAI API and Pydantic models. If you want to see full examples, check out the hub examples for single classification and multi-label classification.
Motivation
Text classification is a common problem in many NLP applications, such as spam detection or support ticket categorization. The goal is to provide a systematic way to handle these cases using OpenAI's GPT models in combination with Python data structures.
Single-Label Classification¶
Defining the Structures¶
For single-label classification, we define a Pydantic model with a Literal field for the possible labels.
Literals vs Enums
We prefer using Literal types over enum for classification labels. Literals provide better type checking and are more straightforward to use with Pydantic models.
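As a quick illustration, here is a minimal sketch of the same label field written both ways; the SpamEnum, WithEnum, and WithLiteral names are hypothetical and not part of this tutorial's code. The Literal version needs no extra class and keeps the allowed values visible on the field itself.

# Minimal sketch contrasting the two approaches; SpamEnum, WithEnum, and
# WithLiteral are hypothetical names used only for this comparison.
from enum import Enum
from typing import Literal

from pydantic import BaseModel


class SpamEnum(str, Enum):
    SPAM = "SPAM"
    NOT_SPAM = "NOT_SPAM"


class WithEnum(BaseModel):
    label: SpamEnum  # allowed values live in a separate class


class WithLiteral(BaseModel):
    label: Literal["SPAM", "NOT_SPAM"]  # allowed values are inline and type-checked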
Few-Shot Examples
Including few-shot examples in the model's docstring is crucial for improving the model's classification accuracy. These examples guide the AI in understanding the task and expected outputs.
If you want to learn more prompting tips, check out our prompting guide.
Chain of Thought
Using Chain of Thought has been shown to improve prediction quality by roughly 10%.
from pydantic import BaseModel, Field
from typing import Literal
from openai import OpenAI
import instructor

# Apply the patch to the OpenAI client
# enables response_model keyword
client = instructor.from_openai(OpenAI())


class ClassificationResponse(BaseModel):
    """
    A few-shot example of text classification:

    Examples:
    - "Buy cheap watches now!": SPAM
    - "Meeting at 3 PM in the conference room": NOT_SPAM
    - "You've won a free iPhone! Click here": SPAM
    - "Can you pick up some milk on your way home?": NOT_SPAM
    - "Increase your followers by 10000 overnight!": SPAM
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    label: Literal["SPAM", "NOT_SPAM"] = Field(
        ...,
        description="The predicted class label.",
    )
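Because instructor builds the request schema from the Pydantic model, the class docstring, and therefore the few-shot examples above, travels with the request. A quick sanity check, assuming Pydantic v2:

# The class docstring (including the few-shot examples) becomes the
# "description" entry of the model's JSON schema, which is what accompanies
# the request. Assumes Pydantic v2.
print(ClassificationResponse.model_json_schema()["description"])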
Classifying Text¶
The classify function performs single-label classification.
def classify(data: str) -> ClassificationResponse:
    """Perform single-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=ClassificationResponse,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: <text>{data}</text>",
            },
        ],
    )
Testing and Evaluation¶
Let's run a couple of examples to see if the classifier correctly identifies spam and non-spam messages.
if __name__ == "__main__":
    for text, label in [
        ("Hey Jason! You're awesome", "NOT_SPAM"),
        ("I am a nigerian prince and I need your help.", "SPAM"),
    ]:
        prediction = classify(text)
        assert prediction.label == label
        print(f"Text: {text}, Predicted Label: {prediction.label}")
        #> Text: Hey Jason! You're awesome, Predicted Label: NOT_SPAM
        #> Text: I am a nigerian prince and I need your help., Predicted Label: SPAM
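Because chain_of_thought is just another field on ClassificationResponse, you can read the model's reasoning back out when a prediction surprises you. A minimal sketch, using a made-up input message:

# Inspect the reasoning behind a prediction; the input text here is just a
# made-up example message.
prediction = classify("Congratulations, you have been selected for a prize!")
print(prediction.chain_of_thought)
print(prediction.label)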
Multi-Label Classification¶
Defining the Structures¶
For multi-label classification, we'll update our approach to use Literals instead of enums, and include few-shot examples in the model's docstring.
from typing import List, Literal

from pydantic import BaseModel, Field


class MultiClassPrediction(BaseModel):
    """
    Class for a multi-class label prediction.

    Examples:
    - "My account is locked": ["TECH_ISSUE"]
    - "I can't access my billing info": ["TECH_ISSUE", "BILLING"]
    - "When do you close for holidays?": ["GENERAL_QUERY"]
    - "My payment didn't go through and now I can't log in": ["BILLING", "TECH_ISSUE"]
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    class_labels: List[Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]] = Field(
        ...,
        description="The predicted class labels for the support ticket.",
    )
Classifying Text¶
The multi_classify function performs multi-label classification.
def multi_classify(data: str) -> MultiClassPrediction:
    """Perform multi-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=MultiClassPrediction,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: <ticket>{data}</ticket>",
            },
        ],
    )
Testing and Evaluation¶
Finally, we test the multi-label classification function using a sample support ticket.
# Test multi-label classification
ticket = "My account is locked and I can't access my billing info."
prediction = multi_classify(ticket)
assert "TECH_ISSUE" in prediction.class_labels
assert "BILLING" in prediction.class_labels
print(f"Ticket: {ticket}")
print(f"Predicted Labels: {prediction.class_labels}")
#> Ticket: My account is locked and I can't access my billing info.
#> Predicted Labels: ['TECH_ISSUE', 'BILLING']
By using Literals and including few-shot examples, we've improved both the single-label and multi-label classification implementations. These changes enhance type safety and provide better guidance for the AI model, potentially leading to more accurate classifications.
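One concrete consequence of that type safety: Pydantic validates the Literal fields, so a label outside the allowed set is rejected instead of silently passing through. A minimal sketch:

# Labels outside the Literal set fail validation rather than slipping through.
from pydantic import ValidationError

try:
    ClassificationResponse(
        chain_of_thought="Looks promotional.",
        label="MAYBE_SPAM",  # not one of "SPAM" / "NOT_SPAM"
    )
except ValidationError as e:
    print(e)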