Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -34,7 +34,7 @@ class PretrainedPipeline():
|
|
34 |
def __init__(self):
|
35 |
self.device = torch.device("cpu")
|
36 |
self.generator = Generator() # Instantiate your GAN generator class
|
37 |
-
self.generator.load_state_dict(torch.load("
|
38 |
self.generator.eval()
|
39 |
|
40 |
def generate_image(self):
|
|
|
34 |
def __init__(self):
|
35 |
self.device = torch.device("cpu")
|
36 |
self.generator = Generator() # Instantiate your GAN generator class
|
37 |
+
self.generator.load_state_dict(torch.load("generator.pth", map_location=self.device))
|
38 |
self.generator.eval()
|
39 |
|
40 |
def generate_image(self):
|