File size: 2,180 Bytes
b313f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest
from app import app
from helper_functions import predict_class, transform_list_of_texts, prepare_text, inference
import torch
from transformers import DistilBertForSequenceClassification, AutoTokenizer

@pytest.fixture
def client():
    """Yield a test client for ``app`` with its ``TESTING`` config flag enabled.

    The client is created inside a ``with`` block, so it is closed again once
    the test that requested the fixture has finished.
    """
    app.config['TESTING'] = True
    with app.test_client() as test_client:
        yield test_client

# Unit tests

def test_predict_class():
    """Check the types and size of what ``predict_class`` returns for one text.

    The real ``distilbert-base-uncased`` checkpoint is loaded (nothing is
    mocked), and the call is expected to return a tuple plus a dict holding
    one entry per class.
    """
    expected_classes = 17

    # The base checkpoint has no fine-tuned classification head, so without an
    # explicit ``num_labels`` the head gets the library's default size, which
    # cannot match the 17 classes asserted below.
    # NOTE(review): assumes predict_class emits one probability per model
    # output label - confirm against helper_functions.
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=expected_classes
    )
    text = ["This is a sample text for testing."]

    predicted_class, class_probabilities = predict_class(text, model)

    assert isinstance(predicted_class, tuple)
    assert isinstance(class_probabilities, dict)
    assert len(class_probabilities) == expected_classes

def test_transform_list_of_texts():
    """Check that ``transform_list_of_texts`` returns a dict of model inputs.

    Two short texts are passed through together with the real
    ``distilbert-base-uncased`` tokenizer; the result must be a dict that
    exposes both ``input_ids`` and ``attention_mask``.
    """
    sample_texts = ["This is a sample text.", "Another sample text."]
    bert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # NOTE(review): the four numeric arguments are positional; their meaning
    # is defined by helper_functions.transform_list_of_texts - confirm there.
    encoded = transform_list_of_texts(sample_texts, bert_tokenizer, 510, 510, 1, 2550)

    assert isinstance(encoded, dict)
    for required_key in ("input_ids", "attention_mask"):
        assert required_key in encoded

# Integration tests

def test_pdf_upload(client):
    """POST a PDF to ``/pdf/upload`` and expect class probabilities in the reply.

    Reads ``sample.pdf`` via a relative path, so that file has to exist in the
    directory the tests are run from.
    """
    with open('sample.pdf', 'rb') as pdf_handle:
        # The request is sent while the handle is still open so the client can
        # read the file contents.
        reply = client.post(
            '/pdf/upload',
            data={'file': (pdf_handle, 'sample.pdf')},
            content_type='multipart/form-data',
        )

    assert reply.status_code == 200
    assert b'class_probabilities' in reply.data

def test_sentence_endpoint(client):
    """POST a ``text`` field to ``/sentence`` and expect a predicted class back."""
    payload = {'text': 'This is a sample sentence for testing.'}

    reply = client.post('/sentence', data=payload)

    assert reply.status_code == 200
    assert b'predicted_class' in reply.data

def test_voice_endpoint(client):
    """POST a WAV file to ``/voice`` and expect extracted text in the reply.

    Reads ``sample_audio.wav`` via a relative path, so that file has to exist
    in the directory the tests are run from.
    """
    with open('sample_audio.wav', 'rb') as wav_handle:
        # The request is sent while the handle is still open so the client can
        # read the file contents.
        reply = client.post(
            '/voice',
            data={'audio': (wav_handle, 'sample_audio.wav')},
            content_type='multipart/form-data',
        )

    assert reply.status_code == 200
    assert b'extracted_text' in reply.data