Spaces:
Runtime error
Runtime error
Fixed normalization
Browse files
app.py
CHANGED
@@ -43,7 +43,7 @@ def waveformer(audio, label_choices):
|
|
43 |
if fs != 44100:
|
44 |
raise ValueError(fs)
|
45 |
mixture = torch.from_numpy(
|
46 |
-
mixture).unsqueeze(0).unsqueeze(0).to(torch.float)
|
47 |
|
48 |
# Construct the query vector
|
49 |
if len(label_choices) == 0:
|
@@ -53,7 +53,7 @@ def waveformer(audio, label_choices):
|
|
53 |
query[0, TARGETS.index(t)] = 1.
|
54 |
|
55 |
with torch.no_grad():
|
56 |
-
output = model(mixture, query)
|
57 |
|
58 |
return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
|
59 |
|
|
|
43 |
if fs != 44100:
|
44 |
raise ValueError(fs)
|
45 |
mixture = torch.from_numpy(
|
46 |
+
mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
|
47 |
|
48 |
# Construct the query vector
|
49 |
if len(label_choices) == 0:
|
|
|
53 |
query[0, TARGETS.index(t)] = 1.
|
54 |
|
55 |
with torch.no_grad():
|
56 |
+
output = (2.0 ** 15) * model(mixture, query)
|
57 |
|
58 |
return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
|
59 |
|