---
license: mit
pipeline_tag: text-generation
widget:
  - text: '@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@'
    example_title: how r u
  - text: '@@ПЕРВЫЙ@@ что ты делал на выходных? @@ВТОРОЙ@@'
    example_title: wyd
language:
  - ru
tags:
  - conversational
---

This generation model is based on sberbank-ai/rugpt3medium_based_on_gpt2. It was trained on a large corpus of dialogue data and can be used to build generative conversational agents.

The model was trained with a context size of 3.
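The widget examples mark each turn with `@@ПЕРВЫЙ@@` ("first" speaker) or `@@ВТОРОЙ@@` ("second" speaker) and appear to end with the marker of the speaker whose reply should be generated. A minimal sketch of building such a prompt from a dialogue history follows. `build_prompt` is a hypothetical helper, and it assumes that "context size 3" means the last three utterances, that speakers strictly alternate, and that the latest utterance always belongs to the first speaker.

```python
from typing import List

# Speaker markers as they appear in the widget examples.
FIRST = "@@ПЕРВЫЙ@@"
SECOND = "@@ВТОРОЙ@@"


def build_prompt(turns: List[str], context_size: int = 3) -> str:
    """Format a dialogue history as a prompt for the model.

    Hypothetical helper. Assumptions: speakers alternate, the last item in
    `turns` was said by the first speaker (the user), and "context size"
    counts utterances, so only the last `context_size` turns are kept.
    """
    recent = turns[-context_size:]
    parts = []
    for i, text in enumerate(recent):
        # Count back from the end so the final user turn is always FIRST.
        from_end = len(recent) - 1 - i
        marker = FIRST if from_end % 2 == 0 else SECOND
        parts.append(f"{marker} {text}")
    # Open the second speaker's slot so the model generates that reply.
    parts.append(SECOND)
    return " ".join(parts)


print(build_prompt(["привет", "привет", "как дела?"]))
```

For the history `["привет", "привет", "как дела?"]` this reproduces the first widget prompt exactly; with a single turn it reproduces the second one. Longer histories are truncated to the most recent three utterances before formatting.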

On a private validation set, we calculated the metrics introduced in the Meena paper (Adiwardana et al., 2020, "Towards a Human-like Open-Domain Chatbot"):

- Sensibleness: crowd workers were asked whether the model's response makes sense given the context.
- Specificity: crowd workers were asked whether the model's response is specific to the given context; in other words, we don't want the model to give generic, boring responses.
- SSA (Sensibleness Specificity Average): the average of the two metrics above.

| Model | Sensibleness | Specificity | SSA |
|---|---|---|---|
| tinkoff-ai/ruDialoGPT-small | 0.64 | 0.5 | 0.57 |
| tinkoff-ai/ruDialoGPT-medium | 0.78 | 0.69 | 0.735 |
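SSA is simply the mean of the sensibleness and specificity rates. A minimal sketch of the computation follows, assuming each evaluated response has already been reduced to a pair of boolean labels (how the crowd labels were aggregated for the table above is not stated). `ssa` is a hypothetical helper. It also counts a non-sensible response as not specific, which is our reading of the original SSA protocol rather than something this card specifies.

```python
from statistics import mean
from typing import List, Tuple


def ssa(labels: List[Tuple[bool, bool]]) -> Tuple[float, float, float]:
    """Return (sensibleness, specificity, SSA) from per-response labels.

    Each item is (sensible, specific) for one model response.
    Assumption: a response that is not sensible counts as not specific.
    """
    sensible = [float(s) for s, _ in labels]
    specific = [float(s and sp) for s, sp in labels]
    sens, spec = mean(sensible), mean(specific)
    # SSA is the plain average of the two rates.
    return sens, spec, (sens + spec) / 2


print(ssa([(True, True), (True, False), (False, True), (False, False)]))
```

As a sanity check on the table: (0.78 + 0.69) / 2 = 0.735 and (0.64 + 0.5) / 2 = 0.57.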

How to use:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
model = AutoModelForCausalLM.from_pretrained('tinkoff-ai/ruDialoGPT-medium')

# Turns are marked with @@ПЕРВЫЙ@@ ("first" speaker) and @@ВТОРОЙ@@ ("second" speaker).
# English: "@@FIRST@@ hi @@SECOND@@ hi @@FIRST@@ how are you? @@SECOND@@"
inputs = tokenizer('@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@', return_tensors='pt')

with torch.inference_mode():
    generated_token_ids = model.generate(
        **inputs,
        top_k=10,
        top_p=0.95,
        num_beams=3,
        num_return_sequences=3,
        do_sample=True,
        no_repeat_ngram_size=2,
        temperature=1.2,
        repetition_penalty=1.2,
        length_penalty=1.0,
        eos_token_id=50257,
    )

# Each of the 3 returned sequences holds the input context followed by the generated response.
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
print(context_with_response)
```
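Each decoded sample contains the prompt followed by the generated text, and a sample may run on into another speaker turn. The sketch below keeps only the reply. `extract_reply` is a hypothetical helper; it assumes the decoded string begins with the original prompt text and cuts at the first speaker marker found after it.

```python
# Speaker markers as they appear in the prompt format.
FIRST = "@@ПЕРВЫЙ@@"
SECOND = "@@ВТОРОЙ@@"


def extract_reply(context: str, decoded: str) -> str:
    """Return only the generated reply from one decoded sample.

    Hypothetical helper. Assumes `decoded` starts with `context`; if it does
    not, the whole string is treated as the reply.
    """
    reply = decoded[len(context):] if decoded.startswith(context) else decoded
    # Truncate at the earliest speaker marker, i.e. where the next turn begins.
    for marker in (FIRST, SECOND):
        idx = reply.find(marker)
        if idx != -1:
            reply = reply[:idx]
    return reply.strip()


ctx = "@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@"
print(extract_reply(ctx, ctx + " привет, как дела? @@ПЕРВЫЙ@@ норм"))
```

Working on token ids instead avoids relying on an exact string round-trip: slice `generated_token_ids[:, inputs['input_ids'].shape[-1]:]` and decode only that part.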