dgaff committed · verified · Commit cb65d6c · 1 Parent(s): ac6a894

Update README.md

Files changed (1): README.md (+83 -1)
README.md CHANGED
@@ -2,4 +2,86 @@
  license: mit
  base_model:
  - distilbert/distilbert-base-uncased
- ---
+ ---
+
+ Deepest apologies for how fucked up this is, but here is how to download the model files and run inference:
+
+ ```python
+ import os
+ import sys
+ import json
+ import torch
+ from huggingface_hub import hf_hub_download
+ import importlib.util
+
+ # Repository ID and filenames
+ repo_id = "dgaff/bsky_user_classifier"
+ files_to_download = {
+     "model_weights": "multioutput_regressor.pth",
+     "train_script": "train.py",
+     "data_processing": "data_processing.py",
+     "utils": "utils.py",
+     "label_mappings": "label_mappings.json",
+ }
+
+ # Download necessary files
+ model_weights_path = hf_hub_download(repo_id=repo_id, filename=files_to_download["model_weights"])
+ train_script_path = hf_hub_download(repo_id=repo_id, filename=files_to_download["train_script"])
+ data_processing_path = hf_hub_download(repo_id=repo_id, filename=files_to_download["data_processing"])
+ util_path = hf_hub_download(repo_id=repo_id, filename=files_to_download["utils"])
+ label_mappings_path = hf_hub_download(repo_id=repo_id, filename=files_to_download["label_mappings"])
+
+ # Update sys.path to include dependencies
+ for path in [data_processing_path, util_path]:
+     dir_path = os.path.dirname(path)
+     if dir_path not in sys.path:
+         sys.path.append(dir_path)
+
+ # Load train.py as a module
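+ # (train.py is expected to define the MultiOutputRegressor and EmbeddingGenerator classes used below)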
+ spec = importlib.util.spec_from_file_location("train_module", train_script_path)
+ train_module = importlib.util.module_from_spec(spec)
+ sys.modules["train_module"] = train_module
+ spec.loader.exec_module(train_module)
+
+ # Load label mappings
+ with open(label_mappings_path) as f:
+     label_mappings = json.load(f)
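+ # label_mappings["id2label"] maps each output index to a label name; it is used when printing predictions below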
+
+ # Initialize the model
+ hidden_size = 768  # Ensure this matches your model's configuration
+ num_outputs = 23  # Update if different
+ model = train_module.MultiOutputRegressor(hidden_size=hidden_size, num_outputs=num_outputs)
+
+ # Load weights and set model to evaluation mode
+ model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
+ model.eval()
+
+ # Set device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+ # Prepare input sentences and generate embeddings
+ new_sentences = [
+     "This is a test sentence.",
+     "Another example of a sentence to predict."
+ ]
+ embedder = train_module.EmbeddingGenerator()
+ new_embeddings = embedder.generate_embeddings(new_sentences)
+ new_embeddings_tensor = torch.tensor(new_embeddings, dtype=torch.float).to(device)
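+ # Assumption: generate_embeddings returns one hidden_size-dimensional vector per sentence, matching the regressor's input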
+
+ # Generate predictions
+ with torch.no_grad():
+     predictions = model(new_embeddings_tensor).cpu().numpy()
+
+ # Map predictions to labels and print results
+ for sentence, pred in zip(new_sentences, predictions):
+     label_pred_dict = {label_mappings["id2label"][str(i)]: float(pred[i]) for i in range(len(pred))}
+     print(f"Sentence: {sentence}")
+     print("Predictions:")
+     for label, value in label_pred_dict.items():
+         print(f" {label}: {value}")
+     print()
+
+ ```
+
+ I'll do better next time