DaniilAlpha committed on
Commit
5ee61de
·
1 Parent(s): d55f579

Upload answerer.py

Browse files
Files changed (1) hide show
  1. answerer.py +94 -0
answerer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Generator, List
2
+ import os, gc
3
+ from huggingface_hub import hf_hub_download
4
+ from rwkv.model import RWKV
5
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
6
+
7
+ ### settings ###
8
+
9
+ ###
10
+
11
+
12
+
13
+ os.environ["RWKV_JIT_ON"] = "1"
14
+ # os.environ["RWKV_CUDA_ON"] = "1" # if "1" then use CUDA kernel for seq mode (much faster)
15
+
16
class Answerer:
    """Streams text continuations from an RWKV language model.

    Downloads model weights from the Hugging Face Hub, then acts as a
    callable generator: each call yields progressively longer decoded
    output for a given prompt.
    """

    __model: RWKV
    __pipeline: PIPELINE
    # Maximum number of prompt tokens fed to the model on the first step.
    ctx_limit: int

    def __init__(self, repo: str, filename: str, vocab: str, strategy: str, ctx_limit: int):
        """Download `filename` from Hub repo `repo` and build the pipeline.

        `vocab` is the tokenizer vocabulary name passed to PIPELINE;
        `strategy` is forwarded to RWKV (e.g. "cpu fp32");
        `ctx_limit` caps how much of the prompt is used.
        """
        # Must be set before RWKV initialises its kernels.
        os.environ["RWKV_JIT_ON"] = "1"
        # os.environ["RWKV_CUDA_ON"] = "1"

        self.__model = RWKV(hf_hub_download(repo, filename), strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit

    def __call__(
        self,
        input: str,
        max_output_length_tk: int,
        chaos = .1,
        repetitiveness = .3,
        diversity = 0,
        _count_penalty = 1,
    ) -> Generator[str, None, None]:
        """Generate up to `max_output_length_tk` tokens continuing `input`.

        Yields the accumulated, stripped output text each time a complete
        (non-partial-UTF-8) chunk has been decoded, and once more at the end.

        chaos          -> sampling temperature
        repetitiveness -> top-p (nucleus sampling)
        diversity      -> presence penalty (alpha_presence)
        _count_penalty -> frequency penalty (alpha_frequency)
        """
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban = [],
            token_stop = [0],  # sampling token 0 ends generation
        )

        input = input.strip()

        result: str = ""

        # Fix: counts decay multiplicatively (* 0.996), so values are floats,
        # not ints as the original annotation claimed.
        occurrences: Dict[int, float] = {}
        tokens: List[int] = []
        current_token = None
        state = None
        for _ in range(max_output_length_tk):
            # First step feeds the (truncated) prompt; subsequent steps feed
            # only the previously sampled token.  Fix: compare against None —
            # a token id of 0 is falsy yet would be a valid token to feed.
            out, state = self.__model.forward(
                [current_token] if current_token is not None else self.__pipeline.encode(input)[-self.ctx_limit:],
                state,
            )
            # Apply presence + frequency penalties to already-seen tokens.
            for token in occurrences:
                out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency

            current_token = self.__pipeline.sample_logits(
                out,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            if current_token in args.token_stop: break

            tokens.append(current_token)

            # Decay old counts so the repetition penalty fades with distance.
            for token in occurrences:
                occurrences[token] *= 0.996
            occurrences[current_token] = occurrences.get(current_token, 0) + 1

            tmp = self.__pipeline.decode(tokens)
            # "\ufffd" marks an incomplete UTF-8 sequence: keep buffering
            # tokens until they decode to whole characters.
            if "\ufffd" not in tmp:
                tokens.clear()
                result += tmp
                yield result.strip()

        # Fix: the original did `del out, tmp` unconditionally, which raised
        # NameError when the loop ran zero iterations or hit the stop token
        # before decoding.  Dropping the manual del of loop locals is safe:
        # they go out of scope with the generator and gc.collect() runs below.
        tokens.clear()
        occurrences.clear()
        del occurrences, tokens, current_token, state
        gc.collect()

        yield result.strip()