File size: 2,213 Bytes
5ee61de
8de1402
5ee61de
 
1e30c19
5ee61de
 
707a118
5ee61de
 
 
79d23f5
5ee61de
 
 
1e30c19
5ee61de
 
 
 
 
 
 
8de1402
5ee61de
 
 
 
 
 
 
 
 
 
 
 
 
1e30c19
 
5ee61de
 
 
91c3b11
 
 
 
 
 
5ee61de
91c3b11
 
 
 
 
 
5ee61de
91c3b11
5ee61de
91c3b11
 
5ee61de
91c3b11
 
 
 
 
8de1402
91c3b11
 
 
84f3ea8
 
 
4cec5d2
f61e523
5ee61de
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os, gc
from typing import AsyncGenerator
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from asyncio import sleep

class Answerer:
  """Wraps an RWKV language model and streams completions as an async generator."""

  def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
    # Enable the JIT kernels before the model is constructed.
    os.environ["RWKV_JIT_ON"] = "1"
    # os.environ["RWKV_CUDA_ON"] = "1"

    self.__model = RWKV(f"models/{model}.pth", strategy=strategy)
    self.__pipeline = PIPELINE(self.__model, vocab)
    self.ctx_limit = ctx_limit  # max number of prompt tokens fed on the first step

  async def __call__(
    self,
    input: str,
    max_output_length_tk: int,
    chaos = .1,
    repetitiveness = .3,
    diversity = 0,
    _count_penalty = 1,
  ) -> AsyncGenerator[str, None]:
    """Stream a completion for `input`.

    Yields the cumulative decoded text each time the pending tokens form
    valid UTF-8.  Generation stops after `max_output_length_tk` tokens, on
    the stop token (0), or when the model begins a new "\\n\\nUser:" turn
    (that suffix is stripped from the final yield).

    Args:
      input: prompt text; stripped, then truncated to the last
        `ctx_limit` tokens before being fed to the model.
      max_output_length_tk: hard cap on generated tokens.
      chaos: sampling temperature.
      repetitiveness: top-p nucleus threshold.
      diversity: presence penalty (alpha_presence).
      _count_penalty: frequency penalty (alpha_frequency).
    """
    args = PIPELINE_ARGS(
      temperature=chaos,
      top_p=repetitiveness,
      alpha_frequency=_count_penalty,
      alpha_presence=diversity,
      token_ban = [],
      token_stop = [0],
    )

    input = input.strip()

    result: str = ""

    occurrences: dict[int, int] = {}
    tokens: list[int] = []
    current_token = None
    state = None
    for _ in range(max_output_length_tk):
      # First step feeds the (truncated) prompt; subsequent steps feed only
      # the previously sampled token and reuse the recurrent state.
      # FIX: compare against None — token id 0 is falsy, and the original
      # truthiness test would re-feed the entire prompt for it.
      out, state = self.__model.forward(
        [current_token] if current_token is not None else self.__pipeline.encode(input)[-self.ctx_limit:],
        state,
      )
      # Apply presence + frequency penalties to tokens already emitted.
      for token in occurrences:
        out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency

      current_token = self.__pipeline.sample_logits(
        out,
        temperature=args.temperature,
        top_p=args.top_p,
      )
      if current_token in args.token_stop: break

      tokens.append(current_token)

      # Exponentially decay old counts so the penalty fades with distance.
      for token in occurrences:
        occurrences[token] *= 0.996

      if current_token in occurrences:
        occurrences[current_token] += 1
      else:
        occurrences[current_token] = 1

      tmp: str = self.__pipeline.decode(tokens)
      # Flush only once the pending tokens decode cleanly (no U+FFFD from a
      # multi-byte sequence split across token boundaries).
      if "\ufffd" not in tmp:
        tokens.clear()
        result += tmp
        if result.rstrip().endswith("\n\nUser:"):
          yield result.rstrip().removesuffix("\n\nUser:")
          break
        yield result
        await sleep(.02)  # yield control so other coroutines can run between chunks

    # FIX: the original did `del out, tmp` here, which raised NameError when
    # the loop exited before binding them (max_output_length_tk <= 0, or the
    # stop token sampled on the first step, which breaks before `tmp` exists).
    # Locals are released when the generator finishes; just clear and collect.
    tokens.clear()
    occurrences.clear()
    gc.collect()