Building a Spam Detector with a Naive Bayes Classifier
Detecting and Filtering Unwanted Messages: Building a Spam Detector with a Naive Bayes Classifier
Spam detection is a critical task in natural language processing (NLP) and machine learning, aimed at identifying and filtering out unwanted or malicious messages. In this tutorial, we'll create a simple spam detector using the Naive Bayes classifier with Python's scikit-learn library.
Introduction
Spam detection plays a crucial role in email systems, messaging apps, and other communication platforms, because it helps users avoid irrelevant or harmful messages. Naive Bayes classifiers are a popular choice for this task. They apply Bayes' theorem under the simplifying ("naive") assumption that words occur independently of one another given the class. This makes them fast to train on large vocabularies and effective on word-count features.
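To make the idea concrete, here is a minimal from-scratch sketch of multinomial Naive Bayes in pure Python. It scores each class as log P(class) plus the sum of log P(word | class) over the words in a message, treating words as independent given the class. It also uses add-one (Laplace) smoothing so that a word never seen with a class does not drive that class's score to zero. The function names `train_nb` and `predict_nb` and the toy messages are illustrative assumptions, not part of scikit-learn. scikit-learn's `MultinomialNB` applies the same kind of additive smoothing through its `alpha` parameter (default 1.0).

```python
import math
from collections import Counter


def train_nb(messages, labels, alpha=1.0):
    """Fit a multinomial Naive Bayes model on whitespace-tokenized messages.

    Returns class log-priors, per-class word log-likelihoods, and the vocabulary.
    """
    classes = sorted(set(labels))
    vocab = sorted({w for m in messages for w in m.split()})
    log_prior = {}
    log_likelihood = {}
    for c in classes:
        class_msgs = [m for m, y in zip(messages, labels) if y == c]
        log_prior[c] = math.log(len(class_msgs) / len(messages))
        # Word counts for this class, and their total
        counts = Counter(w for m in class_msgs for w in m.split())
        total = sum(counts.values())
        # Additive (Laplace) smoothing: every vocabulary word gets alpha pseudo-counts
        log_likelihood[c] = {
            w: math.log((counts[w] + alpha) / (total + alpha * len(vocab)))
            for w in vocab
        }
    return log_prior, log_likelihood, vocab


def predict_nb(model, message):
    """Return the class with the highest log-posterior score for one message."""
    log_prior, log_likelihood, _vocab = model
    scores = {}
    for c in log_prior:
        score = log_prior[c]
        for w in message.split():
            if w in log_likelihood[c]:  # words outside the training vocabulary are ignored
                score += log_likelihood[c][w]
        scores[c] = score
    return max(scores, key=scores.get)


# Toy training data: 1 = spam, 0 = ham
messages = [
    "win free prize now",
    "free cash claim now",
    "see you at lunch",
    "call me when you are home",
]
labels = [1, 1, 0, 0]
model = train_nb(messages, labels)
print(predict_nb(model, "claim your free prize"))  # 1
print(predict_nb(model, "see you at home"))        # 0
```

Working in log space avoids numeric underflow when many small word probabilities are multiplied together, which is also why scikit-learn exposes `predict_log_proba` alongside `predict_proba`.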
Dataset Exploration
We'll start by loading the SMS Spam Collection dataset, which contains SMS messages labeled as 'spam' or 'ham' (not spam). Let's explore the dataset and visualize the distribution of spam and ham messages:
import pandas as pd
import matplotlib.pyplot as plt
# Load the dataset
df = pd.read_csv('SMSSpamCollection', sep='\t', names=['label', 'message'])
# Convert labels to binary (0 for ham, 1 for spam)
df['label'] = df['label'].map({'ham': 0, 'spam': 1})
# Visualize the distribution of spam and ham messages
spam_count = df['label'].sum()
ham_count = len(df) - spam_count
plt.figure(figsize=(6, 6))
plt.pie([ham_count, spam_count], labels=['Ham', 'Spam'], autopct='%1.1f%%', startangle=90)
plt.axis('equal')
plt.title('Distribution of Spam and Ham Messages')
plt.show()
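The dataset contains a total of X messages, with Y% of them labeled as spam and Z% labeled as ham. Ham messages far outnumber spam, so the classes are imbalanced. Keep this in mind when interpreting the accuracy score later.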
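To read exact counts and percentages from the data rather than from the chart, a small helper can compute them. This is a minimal pure-Python sketch. The function name `class_distribution` is hypothetical, and it assumes labels are already mapped to 0 (ham) and 1 (spam) as above. With the DataFrame loaded, `class_distribution(df['label'])` would return the values for the real dataset, and `df['label'].value_counts(normalize=True)` gives the same proportions directly in pandas.

```python
def class_distribution(labels):
    """Return (total, spam_pct, ham_pct) for labels encoded as 1 = spam, 0 = ham."""
    labels = list(labels)
    total = len(labels)
    spam = sum(1 for y in labels if y == 1)
    ham = total - spam  # assumes every label is either 0 or 1
    return total, 100 * spam / total, 100 * ham / total


# Toy example: 1 spam message out of 4
total, spam_pct, ham_pct = class_distribution([0, 1, 0, 0])
print(f"{total} messages: {spam_pct:.1f}% spam, {ham_pct:.1f}% ham")
# 4 messages: 25.0% spam, 75.0% ham
```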
Data Preprocessing
Before training the model, we need to preprocess the text data. This involves removing punctuation, converting text to lowercase, and removing stopwords (common words that do not contribute much to the meaning of the text):
import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
# Download the stopword list and tokenizer models if not already downloaded
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('punkt_tab')  # needed by word_tokenize in newer NLTK releases
# Build the English stopword set once instead of on every call
stop_words = set(stopwords.words('english'))
def preprocess_text(text):
    text = text.translate(str.maketrans('', '', string.punctuation))  # Remove punctuation
    text = text.lower()  # Convert text to lowercase
    word_tokens = word_tokenize(text)  # Tokenize text into words
    filtered_text = [word for word in word_tokens if word not in stop_words]  # Remove stopwords
    return ' '.join(filtered_text)
df['message'] = df['message'].apply(preprocess_text)
After preprocessing, each message is a lowercase string of space-separated tokens with punctuation and stopwords removed, ready for feature extraction.
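To see the three preprocessing steps on a single message without downloading NLTK data, here is a dependency-free sketch. It is an approximation of the NLTK-based approach, not a drop-in replacement. A small hand-picked stopword set (hypothetical, far shorter than NLTK's English list) stands in for `stopwords.words('english')`, and `str.split()` stands in for `word_tokenize`. The name `simple_preprocess` is illustrative.

```python
import string

# Hypothetical mini stopword list; NLTK's English list is much longer.
STOP_WORDS = {"a", "an", "the", "to", "you", "your", "have", "is", "are", "for", "now"}


def simple_preprocess(text):
    """Strip ASCII punctuation, lowercase, and drop stopwords (whitespace tokenization)."""
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    return " ".join(w for w in text.split() if w not in STOP_WORDS)


print(simple_preprocess("URGENT! You have WON a £1,000 prize. Call now!"))
# urgent won £1000 prize call
```

Note that `string.punctuation` contains only ASCII punctuation, so a symbol such as `£` survives this step. That can be useful for spam detection, since currency symbols are a common spam signal.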
Feature Extraction
To train our model, we need to convert the text data into numerical features. We'll use the Bag-of-Words (BoW) model, which represents each message as a vector of word counts:
from sklearn.feature_extraction.text import CountVectorizer
count_vectorizer = CountVectorizer()
X = count_vectorizer.fit_transform(df['message'])
y = df['label']
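The CountVectorizer converts a collection of text documents into a sparse matrix of token counts, where each row represents a document (here, a message) and each column represents a unique word in the corpus. Note that the vectorizer above is fitted on all messages before the train/test split, so its vocabulary includes words that appear only in the test set. For a stricter evaluation, split the raw messages first, then call fit_transform on the training portion and transform on the test portion.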
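The sketch below mimics, in pure Python, what `CountVectorizer` does with its default settings. It lowercases the text and extracts tokens with the default pattern `(?u)\b\w\w+\b` (runs of two or more word characters, so one-character tokens are dropped). It then assigns each distinct token a column in alphabetical order and counts occurrences per document. It returns dense lists for readability, whereas `CountVectorizer` returns a sparse matrix. The helper name `bag_of_words` is hypothetical.

```python
import re
from collections import Counter

# Same regular expression as CountVectorizer's default token_pattern
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def bag_of_words(docs):
    """Return (vocabulary, count_rows) for a list of documents."""
    tokenized = [TOKEN_PATTERN.findall(doc.lower()) for doc in docs]
    # Sorted vocabulary: column j corresponds to vocabulary[j]
    vocabulary = sorted({tok for toks in tokenized for tok in toks})
    index = {tok: j for j, tok in enumerate(vocabulary)}
    rows = []
    for toks in tokenized:
        row = [0] * len(vocabulary)
        for tok, n in Counter(toks).items():
            row[index[tok]] = n
        rows.append(row)
    return vocabulary, rows


vocab, rows = bag_of_words(["Free prize now", "call me now", "now now a prize"])
print(vocab)  # ['call', 'free', 'me', 'now', 'prize']
print(rows)   # [[0, 1, 0, 1, 1], [1, 0, 1, 1, 0], [0, 0, 0, 2, 1]]
```

With scikit-learn installed, `count_vectorizer.get_feature_names_out()` lists the learned columns in the same alphabetical order.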
Model Training and Evaluation
Next, we'll split the data into training (80%) and test (20%) sets and train a Multinomial Naive Bayes classifier, the Naive Bayes variant designed for count features such as word counts:
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
# stratify=y keeps the spam/ham ratio the same in the training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
nb_classifier = MultinomialNB()
nb_classifier.fit(X_train, y_train)
y_pred = nb_classifier.predict(X_test)
# Evaluate the model (precision and recall treat spam, label 1, as the positive class)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
# Display evaluation metrics
print("Accuracy: {:.2f}%".format(accuracy * 100))
print("Precision: {:.2f}%".format(precision * 100))
print("Recall: {:.2f}%".format(recall * 100))
# Rows are true classes (ham, spam); columns are predicted classes
print("\nConfusion Matrix:")
print(conf_matrix)
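For labels 0 and 1, scikit-learn's `confusion_matrix` is laid out as `[[TN, FP], [FN, TP]]`, and `precision_score` and `recall_score` use label 1 (spam) as the positive class by default. The sketch below derives the three printed metrics by hand from such a matrix, which is a handy sanity check. The matrix values are made up for illustration and are not results from this dataset. The helper `metrics_from_confusion` is hypothetical.

```python
def metrics_from_confusion(cm):
    """Compute accuracy, precision, and recall from [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    # Precision: of the messages flagged as spam, how many really are spam
    precision = tp / (tp + fp) if tp + fp else 0.0
    # Recall: of the real spam messages, how many were flagged
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall


# Illustrative counts only (not results from the SMS dataset)
acc, prec, rec = metrics_from_confusion([[90, 2], [4, 16]])
print(f"Accuracy: {acc:.2%}  Precision: {prec:.2%}  Recall: {rec:.2%}")
# Accuracy: 94.64%  Precision: 88.89%  Recall: 80.00%
```

Calling `metrics_from_confusion(conf_matrix)` on the matrix computed above should reproduce the accuracy, precision, and recall printed by scikit-learn.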
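Conclusion
In this tutorial, we loaded the SMS Spam Collection and cleaned the messages by removing punctuation, lowercasing, and dropping stopwords. We then converted them into bag-of-words count vectors with CountVectorizer, and trained and evaluated a Multinomial Naive Bayes classifier using accuracy, precision, recall, and a confusion matrix. Because spam is the minority class, precision and recall on spam are more informative than accuracy alone.
Give this code a try and share your thoughts in the comments. Happy coding!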