File size: 2,023 Bytes
5ee61de
 
 
1e30c19
5ee61de
 
707a118
5ee61de
 
 
707a118
5ee61de
 
 
1e30c19
5ee61de
 
 
 
 
 
 
7e854d5
5ee61de
 
 
 
 
 
 
 
 
 
 
 
 
1e30c19
 
5ee61de
 
 
91c3b11
 
 
 
 
 
5ee61de
91c3b11
 
 
 
 
 
5ee61de
91c3b11
5ee61de
91c3b11
 
5ee61de
91c3b11
 
 
 
 
 
 
 
 
4cec5d2
f61e523
5ee61de
 
 
 
 
 
 
4cec5d2
5ee61de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os, gc
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from asyncio import sleep

class Answerer:
  """Async streaming text generator wrapping an RWKV language model.

  Construction loads the model and tokenizer pipeline; calling the instance
  returns an async generator that yields the growing decoded output string
  chunk by chunk as tokens are sampled.
  """

  def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
    """Load the RWKV model.

    model:     path to the model weights file.
    vocab:     tokenizer vocabulary identifier passed to PIPELINE.
    strategy:  RWKV device/precision strategy string (e.g. "cpu fp32").
    ctx_limit: max number of prompt tokens fed to the model (tail-truncated).
    """
    # Must be set before the model is instantiated so RWKV picks it up.
    os.environ["RWKV_JIT_ON"] = "1"
    # os.environ["RWKV_CUDA_ON"] = "1"

    self.__model = RWKV(model, strategy=strategy)
    self.__pipeline = PIPELINE(self.__model, vocab)
    self.ctx_limit = ctx_limit

  async def __call__(
    self,
    input: str,
    max_output_length_tk: int,
    chaos = .1,
    repetitiveness = .3,
    diversity = 0,
    _count_penalty = 1,
  ):
    """Generate up to `max_output_length_tk` tokens continuing `input`.

    Yields the accumulated result string after each decodable chunk, and
    once more (the final text) when generation ends.

    chaos:          sampling temperature.
    repetitiveness: nucleus (top-p) threshold.
    diversity:      presence penalty (alpha_presence).
    _count_penalty: frequency penalty (alpha_frequency).
    """
    args = PIPELINE_ARGS(
      temperature=chaos,
      top_p=repetitiveness,
      alpha_frequency=_count_penalty,
      alpha_presence=diversity,
      token_ban = [],
      token_stop = [0],  # token id 0 terminates generation
    )

    input = input.strip()

    result: str = ""

    # BUGFIX: values become floats via the 0.996 decay below, so dict[int, float].
    occurrences: dict[int, float] = {}
    tokens: list[int] = []
    current_token = None
    state = None
    # BUGFIX: predefine `out` and `tmp` — previously the `del out, tmp` in the
    # cleanup raised NameError when the loop body never ran (max_output_length_tk
    # == 0) or when the stop token was sampled on the very first step.
    out = None
    tmp = ""
    for _ in range(max_output_length_tk):
      # First step feeds the (tail-truncated) prompt; later steps feed only the
      # last sampled token, carrying the recurrent state forward.
      # BUGFIX: compare with `is not None` — token id 0 is falsy, and truthiness
      # here would silently re-feed the whole prompt.
      out, state = self.__model.forward(
        [current_token] if current_token is not None else self.__pipeline.encode(input)[-self.ctx_limit:],
        state,
      )
      # Apply presence + frequency repetition penalties to seen tokens' logits.
      for token in occurrences:
        out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency

      current_token = self.__pipeline.sample_logits(
        out,
        temperature=args.temperature,
        top_p=args.top_p,
      )
      if current_token in args.token_stop: break

      tokens.append(current_token)

      # Exponentially decay existing penalties so distant repeats matter less.
      for token in occurrences:
        occurrences[token] *= 0.996

      occurrences[current_token] = occurrences.get(current_token, 0) + 1

      # Only flush once the pending tokens decode to valid text — a U+FFFD
      # replacement char means we are mid-way through a multi-byte sequence.
      tmp = self.__pipeline.decode(tokens)
      if "\ufffd" not in tmp:
        tokens.clear()
        result += tmp
        yield result
        await sleep(.02)  # yield control so other coroutines can run

    # Drop references to large intermediates before forcing a collection.
    tokens.clear()
    occurrences.clear()
    del out, tmp
    del occurrences, tokens, current_token, state
    gc.collect()

    yield result