import gc
import os

# The rwkv runtime reads these flags when rwkv.model is imported,
# so they must be set before the imports below.
os.environ["RWKV_JIT_ON"] = "1"
# os.environ["RWKV_CUDA_ON"] = "1"  # uncomment to compile the custom CUDA kernel

from asyncio import sleep
from typing import AsyncGenerator

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS


class Answerer:
    """Streams RWKV completions as an async generator."""

    def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
        self.__model = RWKV(f"models/{model}.pth", strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit
    async def __call__(
        self,
        input: str,
        max_output_length_tk: int,
        chaos: float = 0.1,
        repetitiveness: float = 0.3,
        diversity: float = 0,
        _count_penalty: float = 1,
    ) -> AsyncGenerator[str, None]:
        # Map the user-facing knobs onto standard sampling parameters:
        # chaos -> temperature, repetitiveness -> top_p,
        # diversity -> presence penalty, _count_penalty -> frequency penalty.
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban=[],
            token_stop=[0],  # token 0 terminates generation
        )
        input = input.strip()
        result = ""
        occurrences: dict[int, float] = {}  # counts decay, so values are floats
        tokens: list[int] = []
        current_token = None
        state = None

        for _ in range(max_output_length_tk):
            # Feed the whole (truncated) prompt on the first pass, then only
            # the single token sampled on the previous iteration.
            out, state = self.__model.forward(
                [current_token]
                if current_token is not None
                else self.__pipeline.encode(input)[-self.ctx_limit:],
                state,
            )
            # Apply presence/frequency penalties to tokens already emitted.
            for token in occurrences:
                out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency
            current_token = self.__pipeline.sample_logits(
                out,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            if current_token in args.token_stop:
                break
            tokens.append(current_token)

            # Decay old counts so the penalty fades with distance, then
            # record the token just sampled.
            for token in occurrences:
                occurrences[token] *= 0.996
            occurrences[current_token] = occurrences.get(current_token, 0) + 1

            # Only flush once the pending tokens decode to valid UTF-8;
            # "\ufffd" marks an incomplete multi-byte sequence.
            tmp = self.__pipeline.decode(tokens)
            if "\ufffd" not in tmp:
                tokens.clear()
                result += tmp
                # Stop if the model starts writing the next user turn.
                if result.rstrip().endswith("\n\nUser:"):
                    yield result.rstrip().removesuffix("\n\nUser:")
                    break
                yield result
                await sleep(0.02)  # let other tasks run between chunks

        # Release per-call buffers. (The original `del out, tmp` could raise
        # NameError when the loop broke before those names were bound.)
        del occurrences, tokens, current_token, state
        gc.collect()
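

# A minimal usage sketch (not part of the original Space code): drive
# Answerer from an asyncio loop and print the streamed completion. The
# model/vocab names and strategy string below are placeholder assumptions;
# substitute whatever files this Space actually ships under models/.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        answerer = Answerer(
            model="RWKV-4-World-1.5B",     # hypothetical checkpoint under models/
            vocab="rwkv_vocab_v20230424",  # vocab file name passed to PIPELINE
            strategy="cpu fp32",           # e.g. "cuda fp16" on a GPU machine
            ctx_limit=1024,
        )
        # Each yield is the full text decoded so far, so a UI can simply
        # re-render it; here we keep only the final value.
        final = ""
        async for partial in answerer("User: Hello!\n\nBot:", 128):
            final = partial
        print(final)

    asyncio.run(_demo())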