T5 for Search Query Generation

class T5ForSQG:
    """Generate search queries for a topic with a fine-tuned T5 model.

    NOTE(review): ``T5ForConditionalGeneration``, ``T5Tokenizer``,
    ``YourDataSetClass``, ``DataLoader``, ``pd`` and ``torch`` are not imported
    in this snippet; they are expected to be available at module level.
    """

    def __init__(self, model_path):
        """Load the T5 model and tokenizer.

        Args:
            model_path: Path or identifier accepted by ``from_pretrained``.
        """
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)

    def make_queries(self, topic, n=1, device='cpu', batch_size=16):
        """Sample up to ``n`` distinct search queries for ``topic``.

        The prompt ``'make queries: ' + topic`` is repeated ``n`` times. Each
        copy is decoded with top-k / top-p sampling, so identical outputs are
        possible. Only distinct strings are returned.

        Args:
            topic: Subject to generate queries about.
            n: Number of samples to draw. Non-positive values return ``[]``.
            device: Torch device to run generation on. The model is moved
                there before generating.
            batch_size: Maximum number of prompts per generation batch.

        Returns:
            A list of unique decoded query strings (at most ``n``), in the
            order they were first generated.
        """
        if n <= 0:
            # Nothing to generate; DataLoader would also reject batch_size=0.
            return []

        frame = pd.DataFrame(
            {
                'topic': ['make queries: ' + topic] * n,
                # One empty placeholder target per row. The previous
                # `[[]*n]` evaluated to a single-element list, which made
                # DataFrame construction fail whenever n != 1.
                'queries': [[] for _ in range(n)],
            },
            index=range(n),
        )
        # NOTE(review): the two 64s are presumably max source/target token
        # lengths for YourDataSetClass -- confirm against its definition.
        ds = YourDataSetClass(frame, self.tokenizer, 64, 64, 'topic', 'queries')
        loader = DataLoader(
            ds, batch_size=min(n, batch_size), shuffle=False, num_workers=0
        )

        # Inputs are moved to `device` below, so the model must live there too.
        self.model.to(device)
        self.model.eval()

        # A dict keeps insertion order, giving deterministic de-duplication.
        unique = {}
        with torch.no_grad():
            for batch in loader:
                ids = batch['source_ids'].to(device, dtype=torch.long)
                mask = batch['source_mask'].to(device, dtype=torch.long)

                generated_ids = self.model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_length=64,
                    num_beams=1,
                    repetition_penalty=2.5,
                    length_penalty=1.0,
                    do_sample=True,
                    temperature=1.5,
                    top_k=10,
                    top_p=0.95,
                )

                for g in generated_ids:
                    text = self.tokenizer.decode(
                        g,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )
                    unique.setdefault(text, None)

        return list(unique)
Downloads last month: 113

Safetensors
Model size: 223M params
Tensor type: F32
Inference Providers
This model is not currently available through any of the supported third-party Inference Providers, and it is not deployed on the HF Inference API.

Collection including 1rsh/t5-base-search-query-generation