Spaces:
Sleeping
Sleeping
cryptocalypse
commited on
Update psychohistory.py
Browse files- psychohistory.py +76 -46
psychohistory.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
import matplotlib.pyplot as plt
|
2 |
from mpl_toolkits.mplot3d import Axes3D
|
3 |
import networkx as nx
|
4 |
-
import random
|
5 |
import numpy as np
|
6 |
import json
|
7 |
import sys
|
|
|
8 |
|
9 |
def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
|
10 |
-
"""Generates a tree of nodes with positions adjusted on the x-axis, and
|
11 |
if node_count_per_depth is None:
|
12 |
node_count_per_depth = {}
|
13 |
-
|
14 |
if depth not in node_count_per_depth:
|
15 |
node_count_per_depth[depth] = 0
|
16 |
|
@@ -19,13 +19,13 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
|
|
19 |
|
20 |
num_children = random.randint(1, max_nodes)
|
21 |
x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
|
22 |
-
|
23 |
for x in x_positions:
|
24 |
# Add node to the graph
|
25 |
node_id = len(G.nodes)
|
26 |
node_count_per_depth[depth] += 1
|
27 |
prob = random.uniform(0, 1) # Assign random probability
|
28 |
-
G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for
|
29 |
if parent is not None:
|
30 |
G.add_edge(parent, node_id)
|
31 |
# Recursively add child nodes
|
@@ -33,27 +33,39 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
|
|
33 |
|
34 |
return node_count_per_depth
|
35 |
|
|
|
|
|
36 |
def build_graph_from_json(json_data, G):
|
37 |
"""Builds a graph from JSON data."""
|
38 |
-
def add_event(parent_id, event_data,
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
G.add_edge(parent_id, node_id)
|
45 |
-
# Add child events
|
46 |
-
add_event(node_id, {'events': value}, key)
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
data = json.loads(json_data)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
def find_paths(G):
|
56 |
-
"""Finds the paths with the highest and lowest average probability, and the
|
57 |
best_path = None
|
58 |
worst_path = None
|
59 |
longest_duration_path = None
|
@@ -72,30 +84,30 @@ def find_paths(G):
|
|
72 |
if not all('pos' in G.nodes[node] for node in path):
|
73 |
continue # Skip paths with nodes missing the 'pos' attribute
|
74 |
|
75 |
-
# Calculate the
|
76 |
-
probabilities = [G.nodes[node]['pos'][1] for node in path] # Get probabilities
|
77 |
mean_prob = np.mean(probabilities)
|
78 |
|
79 |
-
# Evaluate
|
80 |
if mean_prob > best_mean_prob:
|
81 |
best_mean_prob = mean_prob
|
82 |
best_path = path
|
83 |
|
84 |
-
# Evaluate
|
85 |
if mean_prob < worst_mean_prob:
|
86 |
worst_mean_prob = mean_prob
|
87 |
worst_path = path
|
88 |
|
89 |
-
# Calculate
|
90 |
x_positions = [G.nodes[node]['pos'][0] for node in path]
|
91 |
duration = max(x_positions) - min(x_positions)
|
92 |
|
93 |
-
# Evaluate
|
94 |
if duration > max_duration:
|
95 |
max_duration = duration
|
96 |
longest_duration_path = path
|
97 |
|
98 |
-
# Evaluate
|
99 |
if duration < min_duration:
|
100 |
min_duration = duration
|
101 |
shortest_duration_path = path
|
@@ -115,7 +127,7 @@ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
|
|
115 |
fig = plt.figure(figsize=(16, 12))
|
116 |
ax = fig.add_subplot(111, projection='3d')
|
117 |
|
118 |
-
# Assign colors to
|
119 |
node_colors = []
|
120 |
for node in path:
|
121 |
prob = G.nodes[node]['pos'][1]
|
@@ -135,31 +147,38 @@ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
|
|
135 |
x_end, y_end, z_end = pos[edge[1]]
|
136 |
ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
|
137 |
|
138 |
-
# Add labels to
|
139 |
for node, (x, y, z) in pos.items():
|
140 |
if node in path:
|
141 |
ax.text(x, y, z, str(node), fontsize=12, color='black')
|
142 |
|
143 |
-
#
|
144 |
ax.set_xlabel('Time (weeks)')
|
145 |
ax.set_ylabel('Event Probability')
|
146 |
ax.set_zlabel('Event Number')
|
147 |
-
ax.set_title('Event Tree
|
|
|
|
|
|
|
148 |
|
149 |
-
plt.savefig(filename, bbox_inches='tight') # Save to a file with adjusted margins
|
150 |
-
plt.close() # Close the figure to free up resources
|
151 |
|
152 |
def draw_global_tree_3d(G, filename='global_tree.png'):
|
153 |
"""Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
|
154 |
pos = nx.get_node_attributes(G, 'pos')
|
|
|
155 |
|
|
|
|
|
|
|
|
|
|
|
156 |
# Get data for 3D visualization
|
157 |
x_vals, y_vals, z_vals = zip(*pos.values())
|
158 |
|
159 |
fig = plt.figure(figsize=(16, 12))
|
160 |
ax = fig.add_subplot(111, projection='3d')
|
161 |
|
162 |
-
# Assign colors to
|
163 |
node_colors = []
|
164 |
for node, (x, prob, z) in pos.items():
|
165 |
if prob < 0.33:
|
@@ -178,18 +197,19 @@ def draw_global_tree_3d(G, filename='global_tree.png'):
|
|
178 |
x_end, y_end, z_end = pos[edge[1]]
|
179 |
ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
|
180 |
|
181 |
-
# Add labels to
|
182 |
for node, (x, y, z) in pos.items():
|
183 |
-
|
|
|
184 |
|
185 |
-
#
|
186 |
-
ax.set_xlabel('Time
|
187 |
-
ax.set_ylabel('
|
188 |
ax.set_zlabel('Event Number')
|
189 |
-
ax.set_title('Event Tree
|
190 |
|
191 |
-
plt.savefig(filename, bbox_inches='tight') # Save to
|
192 |
-
plt.close() # Close the figure to free
|
193 |
|
194 |
def main(mode, input_file=None):
|
195 |
G = nx.DiGraph()
|
@@ -197,12 +217,14 @@ def main(mode, input_file=None):
|
|
197 |
if mode == 'random':
|
198 |
starting_x = 0
|
199 |
starting_y = 0
|
200 |
-
max_depth = 5 # Maximum tree
|
201 |
max_nodes = 3 # Maximum number of child nodes
|
202 |
-
x_range = 10 # Maximum range for
|
203 |
|
204 |
-
# Generate the tree and get
|
205 |
generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
|
|
|
|
|
206 |
elif mode == 'json' and input_file:
|
207 |
with open(input_file, 'r') as file:
|
208 |
json_data = file.read()
|
@@ -211,10 +233,14 @@ def main(mode, input_file=None):
|
|
211 |
print("Invalid mode or input file not provided.")
|
212 |
return
|
213 |
|
|
|
|
|
|
|
|
|
214 |
# Find relevant paths
|
215 |
best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
|
216 |
|
217 |
-
# Print
|
218 |
if best_path:
|
219 |
print(f"\nPath with the highest average probability:")
|
220 |
print(" -> ".join(map(str, best_path)))
|
@@ -251,10 +277,14 @@ def main(mode, input_file=None):
|
|
251 |
if shortest_duration_path:
|
252 |
draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
|
253 |
|
|
|
|
|
254 |
if __name__ == "__main__":
|
255 |
if len(sys.argv) < 2:
|
256 |
-
print("Usage: python script.py <mode> [
|
257 |
else:
|
258 |
mode = sys.argv[1]
|
259 |
input_file = sys.argv[2] if len(sys.argv) > 2 else None
|
260 |
main(mode, input_file)
|
|
|
|
|
|
1 |
import matplotlib.pyplot as plt
|
2 |
from mpl_toolkits.mplot3d import Axes3D
|
3 |
import networkx as nx
|
|
|
4 |
import numpy as np
|
5 |
import json
|
6 |
import sys
|
7 |
+
import random
|
8 |
|
9 |
def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
|
10 |
+
"""Generates a tree of nodes with positions adjusted on the x-axis, y-axis, and number of nodes on the z-axis."""
|
11 |
if node_count_per_depth is None:
|
12 |
node_count_per_depth = {}
|
13 |
+
|
14 |
if depth not in node_count_per_depth:
|
15 |
node_count_per_depth[depth] = 0
|
16 |
|
|
|
19 |
|
20 |
num_children = random.randint(1, max_nodes)
|
21 |
x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
|
22 |
+
|
23 |
for x in x_positions:
|
24 |
# Add node to the graph
|
25 |
node_id = len(G.nodes)
|
26 |
node_count_per_depth[depth] += 1
|
27 |
prob = random.uniform(0, 1) # Assign random probability
|
28 |
+
G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for z position
|
29 |
if parent is not None:
|
30 |
G.add_edge(parent, node_id)
|
31 |
# Recursively add child nodes
|
|
|
33 |
|
34 |
return node_count_per_depth
|
35 |
|
36 |
+
|
37 |
+
|
38 |
def build_graph_from_json(json_data, G):
|
39 |
"""Builds a graph from JSON data."""
|
40 |
+
def add_event(parent_id, event_data, depth):
|
41 |
+
"""Recursively adds events and subevents to the graph."""
|
42 |
+
# Add the current event node
|
43 |
+
node_id = len(G.nodes)
|
44 |
+
prob = event_data['probability'] / 100.0 # Convert percentage to probability
|
45 |
+
pos = (depth, prob, event_data['event_number']) # Use event_number for z position
|
46 |
+
label = event_data['name'] # Use event name as label
|
47 |
+
G.add_node(node_id, pos=pos, label=label)
|
48 |
+
if parent_id is not None:
|
49 |
G.add_edge(parent_id, node_id)
|
|
|
|
|
50 |
|
51 |
+
# Add child events
|
52 |
+
subevents = event_data.get('subevents', {}).get('event', [])
|
53 |
+
if not isinstance(subevents, list):
|
54 |
+
subevents = [subevents] # Ensure subevents is a list
|
55 |
+
|
56 |
+
for subevent in subevents:
|
57 |
+
add_event(node_id, subevent, depth + 1)
|
58 |
+
|
59 |
data = json.loads(json_data)
|
60 |
+
root_id = len(G.nodes)
|
61 |
+
root_event = list(data.get('events', {}).values())[0]
|
62 |
+
G.add_node(root_id, pos=(0, root_event['probability'] / 100.0, root_event['event_number']), label=root_event['name'])
|
63 |
+
add_event(None, root_event, 0) # Start from the root
|
64 |
+
|
65 |
+
|
66 |
|
67 |
def find_paths(G):
|
68 |
+
"""Finds the paths with the highest and lowest average probability, and the longest and shortest durations in graph G."""
|
69 |
best_path = None
|
70 |
worst_path = None
|
71 |
longest_duration_path = None
|
|
|
84 |
if not all('pos' in G.nodes[node] for node in path):
|
85 |
continue # Skip paths with nodes missing the 'pos' attribute
|
86 |
|
87 |
+
# Calculate the mean probability of the path
|
88 |
+
probabilities = [G.nodes[node]['pos'][1] for node in path] # Get node probabilities
|
89 |
mean_prob = np.mean(probabilities)
|
90 |
|
91 |
+
# Evaluate path with the highest mean probability
|
92 |
if mean_prob > best_mean_prob:
|
93 |
best_mean_prob = mean_prob
|
94 |
best_path = path
|
95 |
|
96 |
+
# Evaluate path with the lowest mean probability
|
97 |
if mean_prob < worst_mean_prob:
|
98 |
worst_mean_prob = mean_prob
|
99 |
worst_path = path
|
100 |
|
101 |
+
# Calculate path duration
|
102 |
x_positions = [G.nodes[node]['pos'][0] for node in path]
|
103 |
duration = max(x_positions) - min(x_positions)
|
104 |
|
105 |
+
# Evaluate path with the longest duration
|
106 |
if duration > max_duration:
|
107 |
max_duration = duration
|
108 |
longest_duration_path = path
|
109 |
|
110 |
+
# Evaluate path with the shortest duration
|
111 |
if duration < min_duration:
|
112 |
min_duration = duration
|
113 |
shortest_duration_path = path
|
|
|
127 |
fig = plt.figure(figsize=(16, 12))
|
128 |
ax = fig.add_subplot(111, projection='3d')
|
129 |
|
130 |
+
# Assign colors to nodes based on probability
|
131 |
node_colors = []
|
132 |
for node in path:
|
133 |
prob = G.nodes[node]['pos'][1]
|
|
|
147 |
x_end, y_end, z_end = pos[edge[1]]
|
148 |
ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
|
149 |
|
150 |
+
# Add labels to nodes
|
151 |
for node, (x, y, z) in pos.items():
|
152 |
if node in path:
|
153 |
ax.text(x, y, z, str(node), fontsize=12, color='black')
|
154 |
|
155 |
+
# Set labels and title
|
156 |
ax.set_xlabel('Time (weeks)')
|
157 |
ax.set_ylabel('Event Probability')
|
158 |
ax.set_zlabel('Event Number')
|
159 |
+
ax.set_title('3D Event Tree - Path')
|
160 |
+
|
161 |
+
plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
|
162 |
+
plt.close() # Close the figure to free resources
|
163 |
|
|
|
|
|
164 |
|
165 |
def draw_global_tree_3d(G, filename='global_tree.png'):
|
166 |
"""Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
|
167 |
pos = nx.get_node_attributes(G, 'pos')
|
168 |
+
labels = nx.get_node_attributes(G, 'label')
|
169 |
|
170 |
+
# Check if the graph is empty
|
171 |
+
if not pos:
|
172 |
+
print("Graph is empty. No nodes to visualize.")
|
173 |
+
return
|
174 |
+
|
175 |
# Get data for 3D visualization
|
176 |
x_vals, y_vals, z_vals = zip(*pos.values())
|
177 |
|
178 |
fig = plt.figure(figsize=(16, 12))
|
179 |
ax = fig.add_subplot(111, projection='3d')
|
180 |
|
181 |
+
# Assign colors to nodes based on probability
|
182 |
node_colors = []
|
183 |
for node, (x, prob, z) in pos.items():
|
184 |
if prob < 0.33:
|
|
|
197 |
x_end, y_end, z_end = pos[edge[1]]
|
198 |
ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
|
199 |
|
200 |
+
# Add labels to nodes
|
201 |
for node, (x, y, z) in pos.items():
|
202 |
+
label = labels.get(node, f"{node}")
|
203 |
+
ax.text(x, y, z, label, fontsize=12, color='black')
|
204 |
|
205 |
+
# Set labels and title
|
206 |
+
ax.set_xlabel('Time')
|
207 |
+
ax.set_ylabel('Probability')
|
208 |
ax.set_zlabel('Event Number')
|
209 |
+
ax.set_title('3D Event Tree')
|
210 |
|
211 |
+
plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
|
212 |
+
plt.close() # Close the figure to free resources
|
213 |
|
214 |
def main(mode, input_file=None):
|
215 |
G = nx.DiGraph()
|
|
|
217 |
if mode == 'random':
|
218 |
starting_x = 0
|
219 |
starting_y = 0
|
220 |
+
max_depth = 5 # Maximum depth of the tree
|
221 |
max_nodes = 3 # Maximum number of child nodes
|
222 |
+
x_range = 10 # Maximum range for x position of nodes
|
223 |
|
224 |
+
# Generate the tree and get node count per depth
|
225 |
generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
|
226 |
+
|
227 |
+
|
228 |
elif mode == 'json' and input_file:
|
229 |
with open(input_file, 'r') as file:
|
230 |
json_data = file.read()
|
|
|
233 |
print("Invalid mode or input file not provided.")
|
234 |
return
|
235 |
|
236 |
+
# Save the global visualization
|
237 |
+
draw_global_tree_3d(G, filename='global_tree.png')
|
238 |
+
|
239 |
+
|
240 |
# Find relevant paths
|
241 |
best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
|
242 |
|
243 |
+
# Print results
|
244 |
if best_path:
|
245 |
print(f"\nPath with the highest average probability:")
|
246 |
print(" -> ".join(map(str, best_path)))
|
|
|
277 |
if shortest_duration_path:
|
278 |
draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
|
279 |
|
280 |
+
|
281 |
+
|
282 |
if __name__ == "__main__":
|
283 |
if len(sys.argv) < 2:
|
284 |
+
print("Usage: python script.py <mode> [input_file]")
|
285 |
else:
|
286 |
mode = sys.argv[1]
|
287 |
input_file = sys.argv[2] if len(sys.argv) > 2 else None
|
288 |
main(mode, input_file)
|
289 |
+
|
290 |
+
|