zfzhang-thu commited on
Commit
587ae20
·
1 Parent(s): 0cfc205

using bf16

Browse files
Files changed (1) hide show
  1. leo/model.py +3 -3
leo/model.py CHANGED
@@ -11,7 +11,7 @@ from leo.grounding_head import SequentialGroundHead
11
  from leo.utils import get_mlp_head
12
 
13
 
14
- def maybe_autocast(model, dtype='float32', enabled=True): ### not-half mode
15
  # if on cpu, don't use autocast
16
  # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
17
  enable_autocast = model.device != torch.device('cpu')
@@ -75,7 +75,7 @@ class SequentialGrounder(torch.nn.Module):
75
  if 'vicuna' in llm_name.lower():
76
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
77
  self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
78
- self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float32) # not-half mode torch_dtype=torch.float16
79
  self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
80
  else:
81
  self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
@@ -320,7 +320,7 @@ class SequentialGrounder(torch.nn.Module):
320
 
321
  with maybe_autocast(self):
322
  outputs = self.llm_model(
323
- inputs_embeds=inputs_embeds.float(), # not-half mode
324
  attention_mask=attention_mask,
325
  return_dict=True,
326
  output_hidden_states=True,
 
11
  from leo.utils import get_mlp_head
12
 
13
 
14
+ def maybe_autocast(model, dtype='bf16', enabled=True):
15
  # if on cpu, don't use autocast
16
  # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
17
  enable_autocast = model.device != torch.device('cpu')
 
75
  if 'vicuna' in llm_name.lower():
76
  self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
77
  self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
78
+ self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
79
  self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
80
  else:
81
  self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side)
 
320
 
321
  with maybe_autocast(self):
322
  outputs = self.llm_model(
323
+ inputs_embeds=inputs_embeds,
324
  attention_mask=attention_mask,
325
  return_dict=True,
326
  output_hidden_states=True,