席亚东 commited on
Commit
16bf127
·
1 Parent(s): ef2abea

fix the bug in inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -2
inference.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from torch.nn.utils.rnn import pad_sequence
8
 
9
  from fairseq import checkpoint_utils, options, tasks, utils
 
10
 
11
  Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
12
 
@@ -77,7 +78,12 @@ class Inference(object):
77
  use_cuda = torch.cuda.is_available() and not args.cpu
78
  self.use_cuda = use_cuda
79
 
80
- # Optimize ensemble for generation
 
 
 
 
 
81
  state = torch.load(args.path, map_location=torch.device("cpu"))
82
  cfg_args = eval(str(state["cfg"]))["model"]
83
  del cfg_args["_name"]
@@ -97,7 +103,6 @@ class Inference(object):
97
  "max_batch":eet_batch_size,
98
  "full_seq_len":eet_seq_len}
99
  print(model_args)
100
- from eet.fairseq.transformer import EETTransformerDecoder
101
  eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
102
  dictionary = self.src_dict,args=model_args,
103
  config = eet_config,
 
7
  from torch.nn.utils.rnn import pad_sequence
8
 
9
  from fairseq import checkpoint_utils, options, tasks, utils
10
+ from eet.fairseq.transformer import EETTransformerDecoder
11
 
12
  Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
13
 
 
78
  use_cuda = torch.cuda.is_available() and not args.cpu
79
  self.use_cuda = use_cuda
80
 
81
+ model_path = args.path
82
+ checkpoint = torch.load(model_path.replace("best.pt", "best_part_1.pt"))
83
+ checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_2.pt")))
84
+ checkpoint["model"].update(torch.load(model_path.replace("best.pt", "best_part_3.pt")))
85
+ torch.save(checkpoint, model_path)
86
+
87
  state = torch.load(args.path, map_location=torch.device("cpu"))
88
  cfg_args = eval(str(state["cfg"]))["model"]
89
  del cfg_args["_name"]
 
103
  "max_batch":eet_batch_size,
104
  "full_seq_len":eet_seq_len}
105
  print(model_args)
 
106
  eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
107
  dictionary = self.src_dict,args=model_args,
108
  config = eet_config,