{ "cells": [ { "cell_type": "code", "execution_count": 19, "id": "297ea6c7-1eae-47fc-8fa3-d49e6d3deb6c", "metadata": {}, "outputs": [], "source": [ "import json\n", "import torch\n", "from transformers import AutoModelForSequenceClassification, AutoTokenizer" ] }, { "cell_type": "code", "execution_count": 4, "id": "45ccf708-2a0b-43a1-bf5f-45294ab205d4", "metadata": {}, "outputs": [], "source": [ "with open(\"twiz-data/all_intents.json\", 'r') as json_in:\n", " data = json.load(json_in)" ] }, { "cell_type": "code", "execution_count": 7, "id": "d9875b16-36f8-4289-9ddf-6907f74a975c", "metadata": {}, "outputs": [], "source": [ "id_to_intent, intent_to_id = dict(), dict()\n", "for i, intent in enumerate(data):\n", " id_to_intent[i] = intent\n", " intent_to_id[intent] = i" ] }, { "cell_type": "code", "execution_count": 13, "id": "01a87f85-e4d7-454c-b645-bf252161d458", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(\"roberta-based/checkpoint-925\", num_labels=len(data), id2label=id_to_intent, label2id=intent_to_id)\n", "tokenizer = AutoTokenizer.from_pretrained(\"tokenizer\")" ] }, { "cell_type": "code", "execution_count": 21, "id": "f29489cf-fa4b-453e-8922-6e972db1cc7c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NextStepIntent\n" ] } ], "source": [ "model_in = tokenizer(\"I really really wanna go to the next step\", return_tensors='pt')\n", "with torch.no_grad():\n", " logits = model(**model_in).logits\n", " predicted_class_id = logits.argmax().item()\n", " print(model.config.id2label[predicted_class_id])\n" ] } ], "metadata": { "kernelspec": { "display_name": "ws2024", "language": "python", "name": "ws2024" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }