{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RoBERTa-based Spam Message Detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from detector import SpamMessageDetector"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training the roberta-spam model\n",
    "\n",
    "To start training, set `TRAIN = True` in the next cell. You may skip this step for the demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning is skipped by default; set TRAIN = True to run it.\n",
    "TRAIN = False\n",
    "if TRAIN:\n",
    "    spam_detector = SpamMessageDetector(\"roberta-base\", max_length=512, seed=0)\n",
    "    # NOTE(review): assumes the train/val CSVs sit in data/ next to the test CSV\n",
    "    # used by the evaluation cell -- confirm against the repository layout.\n",
    "    train_data_path = 'data/spam_message_train.csv'\n",
    "    val_data_path = 'data/spam_message_val.csv'\n",
    "    spam_detector.train(train_data_path, val_data_path, num_epochs=10, batch_size=32, learning_rate=2e-5)\n",
    "    # Persist the fine-tuned model under ./roberta-spam.\n",
    "    model_path = 'roberta-spam'\n",
    "    spam_detector.save_model(model_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training Results\n",
    "\n",
    "Loss | Accuracy | Precision / Recall\n",
    ":-------------------------:|:-------------------------:|:-------------------------:\n",
    "![](plots/train_validation_loss.jpg \"Train / Validation Loss\") Train / Validation | ![](plots/validation_accuracy.jpg \"Validation Accuracy\") Validation | ![](plots/validation_precision_recall.jpg \"Validation Precision / Recall\") Validation"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluating the roberta-spam model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"mshenoda/roberta-spam\" model and evaluate it on the test CSV.\n",
    "# This rebinds spam_detector, so the cells below run even when TRAIN is False.\n",
    "spam_detector = SpamMessageDetector(\"mshenoda/roberta-spam\")\n",
    "spam_detector.evaluate(\"data/spam_message_test.csv\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Testing individual example messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message1 = \"Hey so this sat are we going for the intro pilates only? Or the kickboxing too?\"\n",
    "message2 = \"U have a secret admirer. REVEAL who thinks U R So special. Call 09065174042. To opt out Reply REVEAL STOP. 1.50 per msg recd.\"\n",
    "message3 = \"Dude im no longer a pisces. Im an aquarius now.\"\n",
    "message4 = \"Great News! Call FREEFONE 08006344447 to claim your guaranteed $1000 CASH or $2000 gift. Speak to a live operator NOW!\"\n",
    "\n",
    "# Classify each message with its own detect() call; one loop keeps the\n",
    "# printed labels identical for every example.\n",
    "for example_number, message in enumerate((message1, message2, message3, message4), start=1):\n",
    "    detection = spam_detector.detect(message)\n",
    "    print(f\"\\nExample {example_number}\")\n",
    "    print(\"Input Message:\", message)\n",
    "    print(\"Detected Spam?:\", bool(detection))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Batch processing is supported for multiple messages at once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [message1, message2, message3, message4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Example 1\n",
      "Input Message: Hey so this sat are we going for the intro pilates only? Or the kickboxing too?\n",
      "detected spam: False\n",
      "\n",
      "Example 2\n",
      "Input Message: U have a secret admirer. REVEAL who thinks U R So special. Call 09065174042. To opt out Reply REVEAL STOP. 1.50 per msg recd.\n",
      "detected spam: True\n",
      "\n",
      "Example 3\n",
      "Input Message: Dude im no longer a pisces. Im an aquarius now.\n",
      "detected spam: False\n",
      "\n",
      "Example 4\n",
      "Input Message: Great News! Call FREEFONE 08006344447 to claim your guaranteed $1000 CASH or $2000 gift. Speak to a live operator NOW!\n",
      "detected spam: True\n"
     ]
    }
   ],
   "source": [
    "# A single detect() call receives the whole list; results are read back by\n",
    "# message position.\n",
    "detections = spam_detector.detect(messages)\n",
    "for i, message in enumerate(messages):\n",
    "    print(f\"\\nExample {i + 1}\")\n",
    "    print(\"Input Message:\", message)\n",
    "    print(\"detected spam:\", bool(detections[i]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}