martin-gorner committed
Commit 285adea · 1 Parent(s): a2b7758

more bug fixes

Files changed (2):
  1. app.py +1 -1
  2. models.py +1 -3
app.py CHANGED
@@ -24,7 +24,7 @@ for preset in model_presets:
     chat_template = get_appropriate_chat_template(preset)
     chat_state = ChatState(model, "", chat_template)
     prompt, response = chat_state.send_message("Hello")
-    print("model " + preset + "loaded and initialized.")
+    print("model " + preset + " loaded and initialized.")
     print("The model responded: " + response)
     models.append(model)
 
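Aside (not part of the commit): the app.py bug was a missing space in string concatenation. A minimal sketch, assuming a hypothetical `preset` string, showing how an f-string keeps the spacing visible in a single template and avoids this class of typo:

```python
preset = "gemma_1.1_instruct_2b_en"  # hypothetical preset name, for illustration

# The whole message lives in one template string, so a missing
# space between fragments is easy to spot:
print(f"model {preset} loaded and initialized.")
```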
models.py CHANGED
@@ -41,10 +41,8 @@ def get_default_layout_map(preset_name, device_mesh):
 def log_applied_layout_map(model):
     if "Gemma" in type(model).__name__:
         transformer_decoder_block_name = "decoder_block_1"
-    elif "Llama3" in type(model).__name__ or "Mistral" in type(model).__name__:
+    else: # works for Llama, Mistral, Vicuna
         transformer_decoder_block_name = "transformer_layer_1"
-    else:
-        assert (0, "Model type not recognized. Cannot display model layout.")
 
     # See how layer sharding was applied
     embedding_layer = model.backbone.get_layer("token_embedding")
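Aside on the removed `else` branch: `assert (0, "...")` asserts on a two-element tuple, and any non-empty tuple is truthy, so the assertion could never fire (CPython emits a SyntaxWarning for this pattern). A minimal sketch of the pitfall and the two-argument form the code likely intended, not part of the commit:

```python
# Buggy form: the condition is the tuple (0, "msg"), which is truthy,
# so this always passes silently.
assert (0, "Model type not recognized. Cannot display model layout.")

# Intended form: 0 is the condition, the string is the failure message.
try:
    assert 0, "Model type not recognized. Cannot display model layout."
except AssertionError as e:
    print(e)  # -> Model type not recognized. Cannot display model layout.
```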