bilal6913 commited on
Commit
0b7d904
·
verified ·
1 Parent(s): af46c2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
+ import torch
4
+ import torchaudio
5
+ import os
6
+
7
+ # Initialize Flask app
8
+ app = Flask(__name__)
9
+
10
+ # Load the model and processor
11
+ model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"
12
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
13
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
14
+
15
+ # Define the upload folder
16
+ UPLOAD_FOLDER = 'uploads'
17
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
18
+
19
+ # Ensure the upload folder exists
20
+ if not os.path.exists(UPLOAD_FOLDER):
21
+ os.makedirs(UPLOAD_FOLDER)
22
+
23
+ @app.route('/')
24
+ def index():
25
+ return render_template('index.html')
26
+
27
+ @app.route('/transcribe', methods=['POST'])
28
+ def transcribe_audio():
29
+ if 'file' not in request.files:
30
+ return jsonify({'error': 'No file part'}), 400
31
+
32
+ file = request.files['file']
33
+
34
+ if file.filename == '':
35
+ return jsonify({'error': 'No selected file'}), 400
36
+
37
+ if file:
38
+ # Save the uploaded file
39
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
40
+ file.save(file_path)
41
+
42
+ # Load the audio file
43
+ speech_array, sampling_rate = torchaudio.load(file_path)
44
+ speech_array = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(speech_array)
45
+
46
+ # Process the audio input
47
+ input_values = processor(speech_array.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values
48
+
49
+ # Perform inference
50
+ with torch.no_grad():
51
+ logits = model(input_values).logits
52
+
53
+ # Get the predicted transcription
54
+ predicted_ids = torch.argmax(logits, dim=-1)
55
+ transcription = processor.batch_decode(predicted_ids)
56
+
57
+ return jsonify({'transcription': transcription[0]})
58
+
59
+ return jsonify({'error': 'Something went wrong!'}), 500
60
+
61
+ if __name__ == '__main__':
62
+ app.run(debug=True)
63
+