diff --git a/checkpoint/barc_complete/model_best.pth.tar b/checkpoint/barc_complete/model_best.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..fd0fa82e10dc43d53ac35e6af2fc5bc9bb8bd3cf
--- /dev/null
+++ b/checkpoint/barc_complete/model_best.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0834c7f6a298a707e748da7185bd52a318697a34d7d0462e86cf57e287fa5da3
+size 549078471
diff --git a/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt b/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..7d7ae431a18aed3a51cee275dfc300d68f430487
--- /dev/null
+++ b/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ff03508f6b9431da1c224697ce1c68cab000758215c4a4766e136c28f828f2d
+size 1725484
diff --git a/data/breed_data/NIHMS866262-supplement-2.xlsx b/data/breed_data/NIHMS866262-supplement-2.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..0bcea54381008d956311a639c52f19ec6b26d6c4
--- /dev/null
+++ b/data/breed_data/NIHMS866262-supplement-2.xlsx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd6301ec254452ecb86df745220bef98b69d59794429c5cb452b03bb76e17eae
+size 94169
diff --git a/data/breed_data/complete_abbrev_dict_v2.pkl b/data/breed_data/complete_abbrev_dict_v2.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..94987bc0d384c5ecc32c2eb7d3b54c43fe9e0a75
--- /dev/null
+++ b/data/breed_data/complete_abbrev_dict_v2.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2354d2c7e3b2f7ee88f41234e138b7828d58fa6618c0ed0d0d4b12febaee8801
+size 26517
diff --git a/data/breed_data/complete_summary_breeds_v2.pkl b/data/breed_data/complete_summary_breeds_v2.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0464c827536776681b9c4fcc6ee435a57b82332e
--- /dev/null
+++ b/data/breed_data/complete_summary_breeds_v2.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95461e44d7a6924e1d9879711c865177ac7f15faa1ffb932cb42995c8eae3412
+size 89004
diff --git a/data/smal_data/mean_dog_bone_lengths.txt b/data/smal_data/mean_dog_bone_lengths.txt
new file mode 100644
index 0000000000000000000000000000000000000000..abf7bbf02f8fc4eda65bddf7a5c8eb3ab88d6e38
--- /dev/null
+++ b/data/smal_data/mean_dog_bone_lengths.txt
@@ -0,0 +1,34 @@
+0.0
+0.09044851362705231
+0.1525898575782776
+0.08656660467386246
+0.08330804109573364
+0.17591887712478638
+0.1955687403678894
+0.1663869321346283
+0.20741023123264313
+0.10695090889930725
+0.1955687403678894
+0.1663869321346283
+0.20741020143032074
+0.10695091634988785
+0.19678470492362976
+0.135447695851326
+0.10385762155056
+0.1951410472393036
+0.22369971871376038
+0.14296436309814453
+0.10385762155056
+0.1951410472393036
+0.22369973361492157
+0.14296436309814453
+0.11435563117265701
+0.1225045919418335
+0.055157795548439026
+0.07148551940917969
+0.0759430006146431
+0.09476413577795029
+0.0287716593593359
+0.11548781394958496
+0.15003003180027008
+0.15003003180027008
diff --git a/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl b/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..7cf953b71028d76b12c08a4e04e90a44d0155e5e
--- /dev/null
+++ b/data/smal_data/my_smpl_SMBLD_nbj_v3.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf01081234c09445ede7079083727705e6c13a21a77bf97f305e4ad6527f06df
+size 34904364
diff --git a/data/smal_data/my_smpl_data_SMBLD_v3.pkl b/data/smal_data/my_smpl_data_SMBLD_v3.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0372a9723c39f810cd61076867844b76ce49917a
--- /dev/null
+++ b/data/smal_data/my_smpl_data_SMBLD_v3.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84ad0ef6f85d662464c4d0301adede172bb241158b1cea66a810a930a9473cc8
+size 31841
diff --git a/data/smal_data/symmetry_inds.json b/data/smal_data/symmetry_inds.json
new file mode 100644
index 0000000000000000000000000000000000000000..c17c305b15222b9cba70acd64b46725a2c0332c1
--- /dev/null
+++ b/data/smal_data/symmetry_inds.json
@@ -0,0 +1,3897 @@
+{
+ "center_inds": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 37,
+ 55,
+ 119,
+ 120,
+ 163,
+ 209,
+ 210,
+ 211,
+ 213,
+ 216,
+ 227,
+ 326,
+ 395,
+ 452,
+ 578,
+ 910,
+ 959,
+ 964,
+ 975,
+ 976,
+ 977,
+ 1172,
+ 1175,
+ 1176,
+ 1178,
+ 1194,
+ 1243,
+ 1739,
+ 1796,
+ 1797,
+ 1798,
+ 1799,
+ 1800,
+ 1801,
+ 1802,
+ 1803,
+ 1804,
+ 1805,
+ 1806,
+ 1807,
+ 1808,
+ 1809,
+ 1810,
+ 1811,
+ 1812,
+ 1813,
+ 1814,
+ 1815,
+ 1816,
+ 1817,
+ 1818,
+ 1819,
+ 1820,
+ 1821,
+ 1822,
+ 1823,
+ 1824,
+ 1825,
+ 1826,
+ 1827,
+ 1828,
+ 1829,
+ 1830,
+ 1831,
+ 1832,
+ 1833,
+ 1834,
+ 1835,
+ 1836,
+ 1837,
+ 1838,
+ 1839,
+ 1840,
+ 1842,
+ 1843,
+ 1844,
+ 1845,
+ 1846,
+ 1847,
+ 1848,
+ 1849,
+ 1850,
+ 1851,
+ 1852,
+ 1853,
+ 1854,
+ 1855,
+ 1856,
+ 1857,
+ 1858,
+ 1859,
+ 1860,
+ 1861,
+ 1862,
+ 1863,
+ 1870,
+ 1919,
+ 1960,
+ 1961,
+ 1965,
+ 1967,
+ 2003
+ ],
+ "left_inds": [
+ 2012,
+ 2013,
+ 2014,
+ 2015,
+ 2016,
+ 2017,
+ 2018,
+ 2019,
+ 2020,
+ 2021,
+ 2022,
+ 2023,
+ 2024,
+ 2025,
+ 2026,
+ 2027,
+ 2028,
+ 2029,
+ 2030,
+ 2031,
+ 2032,
+ 2033,
+ 2034,
+ 2035,
+ 2036,
+ 2037,
+ 2038,
+ 2039,
+ 2040,
+ 2041,
+ 2042,
+ 2043,
+ 2044,
+ 2045,
+ 2046,
+ 2047,
+ 2048,
+ 2049,
+ 2050,
+ 2051,
+ 2052,
+ 2053,
+ 2054,
+ 2055,
+ 2056,
+ 2057,
+ 2058,
+ 2059,
+ 2060,
+ 2061,
+ 2062,
+ 2063,
+ 2064,
+ 2065,
+ 2066,
+ 2067,
+ 2068,
+ 2069,
+ 2070,
+ 2071,
+ 2072,
+ 2073,
+ 2074,
+ 2075,
+ 2076,
+ 2077,
+ 2078,
+ 2079,
+ 2080,
+ 2081,
+ 2082,
+ 2083,
+ 2084,
+ 2085,
+ 2086,
+ 2087,
+ 2088,
+ 2089,
+ 2090,
+ 2091,
+ 2092,
+ 2093,
+ 2094,
+ 2095,
+ 2096,
+ 2097,
+ 2098,
+ 2099,
+ 2100,
+ 2101,
+ 2102,
+ 2103,
+ 2104,
+ 2105,
+ 2106,
+ 2107,
+ 2108,
+ 2109,
+ 2110,
+ 2111,
+ 2112,
+ 2113,
+ 2114,
+ 2115,
+ 2116,
+ 2117,
+ 2118,
+ 2119,
+ 2120,
+ 2121,
+ 2122,
+ 2123,
+ 2124,
+ 2125,
+ 2126,
+ 2127,
+ 2128,
+ 2129,
+ 2130,
+ 2131,
+ 2132,
+ 2133,
+ 2134,
+ 2135,
+ 2136,
+ 2137,
+ 2138,
+ 2139,
+ 2140,
+ 2141,
+ 2142,
+ 2143,
+ 2144,
+ 2145,
+ 2146,
+ 2147,
+ 2148,
+ 2149,
+ 2150,
+ 2151,
+ 2152,
+ 2153,
+ 2154,
+ 2155,
+ 2156,
+ 2157,
+ 2158,
+ 2159,
+ 2160,
+ 2161,
+ 2162,
+ 2163,
+ 2164,
+ 2165,
+ 2166,
+ 2167,
+ 2168,
+ 2169,
+ 2170,
+ 2171,
+ 2172,
+ 2173,
+ 2174,
+ 2175,
+ 2176,
+ 2177,
+ 2178,
+ 2179,
+ 2180,
+ 2181,
+ 2182,
+ 2183,
+ 2184,
+ 2185,
+ 2186,
+ 2187,
+ 2188,
+ 2189,
+ 2190,
+ 2191,
+ 2192,
+ 2193,
+ 2194,
+ 2195,
+ 2196,
+ 2197,
+ 2198,
+ 2199,
+ 2200,
+ 2201,
+ 2202,
+ 2203,
+ 2204,
+ 2205,
+ 2206,
+ 2207,
+ 2208,
+ 2209,
+ 2210,
+ 2211,
+ 2212,
+ 2213,
+ 2214,
+ 2215,
+ 2216,
+ 2217,
+ 2218,
+ 2219,
+ 2220,
+ 2221,
+ 2222,
+ 2223,
+ 2224,
+ 2225,
+ 2226,
+ 2227,
+ 2228,
+ 2229,
+ 2230,
+ 2231,
+ 2232,
+ 2233,
+ 2234,
+ 2235,
+ 2236,
+ 2237,
+ 2238,
+ 2239,
+ 2240,
+ 2241,
+ 2242,
+ 2243,
+ 2244,
+ 2245,
+ 2246,
+ 2247,
+ 2248,
+ 2249,
+ 2250,
+ 2251,
+ 2252,
+ 2253,
+ 2254,
+ 2255,
+ 2256,
+ 2257,
+ 2258,
+ 2259,
+ 2260,
+ 2261,
+ 2262,
+ 2263,
+ 2264,
+ 2265,
+ 2266,
+ 2267,
+ 2268,
+ 2269,
+ 2270,
+ 2271,
+ 2272,
+ 2273,
+ 2274,
+ 2275,
+ 2276,
+ 2277,
+ 2278,
+ 2279,
+ 2280,
+ 2281,
+ 2282,
+ 2283,
+ 2284,
+ 2285,
+ 2286,
+ 2287,
+ 2288,
+ 2289,
+ 2290,
+ 2291,
+ 2292,
+ 2293,
+ 2294,
+ 2295,
+ 2296,
+ 2297,
+ 2298,
+ 2299,
+ 2300,
+ 2301,
+ 2302,
+ 2303,
+ 2304,
+ 2305,
+ 2306,
+ 2307,
+ 2308,
+ 2309,
+ 2310,
+ 2311,
+ 2312,
+ 2313,
+ 2314,
+ 2315,
+ 2316,
+ 2317,
+ 2318,
+ 2319,
+ 2320,
+ 2321,
+ 2322,
+ 2323,
+ 2324,
+ 2325,
+ 2326,
+ 2327,
+ 2328,
+ 2329,
+ 2330,
+ 2331,
+ 2332,
+ 2333,
+ 2334,
+ 2335,
+ 2336,
+ 2337,
+ 2338,
+ 2339,
+ 2340,
+ 2341,
+ 2342,
+ 2343,
+ 2344,
+ 2345,
+ 2346,
+ 2347,
+ 2348,
+ 2349,
+ 2350,
+ 2351,
+ 2352,
+ 2353,
+ 2354,
+ 2355,
+ 2356,
+ 2357,
+ 2358,
+ 2359,
+ 2360,
+ 2361,
+ 2362,
+ 2363,
+ 2364,
+ 2365,
+ 2366,
+ 2367,
+ 2368,
+ 2369,
+ 2370,
+ 2371,
+ 2372,
+ 2373,
+ 2374,
+ 2375,
+ 2376,
+ 2377,
+ 2378,
+ 2379,
+ 2380,
+ 2381,
+ 2382,
+ 2383,
+ 2384,
+ 2385,
+ 2386,
+ 2387,
+ 2388,
+ 2389,
+ 2390,
+ 2391,
+ 2392,
+ 2393,
+ 2394,
+ 2395,
+ 2396,
+ 2397,
+ 2398,
+ 2399,
+ 2400,
+ 2401,
+ 2402,
+ 2403,
+ 2404,
+ 2405,
+ 2406,
+ 2407,
+ 2408,
+ 2409,
+ 2410,
+ 2411,
+ 2412,
+ 2413,
+ 2414,
+ 2415,
+ 2416,
+ 2417,
+ 2418,
+ 2419,
+ 2420,
+ 2421,
+ 2422,
+ 2423,
+ 2424,
+ 2425,
+ 2426,
+ 2427,
+ 2428,
+ 2429,
+ 2430,
+ 2431,
+ 2432,
+ 2433,
+ 2434,
+ 2435,
+ 2436,
+ 2437,
+ 2438,
+ 2439,
+ 2440,
+ 2441,
+ 2442,
+ 2443,
+ 2444,
+ 2445,
+ 2446,
+ 2447,
+ 2448,
+ 2449,
+ 2450,
+ 2451,
+ 2452,
+ 2453,
+ 2454,
+ 2455,
+ 2456,
+ 2457,
+ 2458,
+ 2459,
+ 2460,
+ 2461,
+ 2462,
+ 2463,
+ 2464,
+ 2465,
+ 2466,
+ 2467,
+ 2468,
+ 2469,
+ 2470,
+ 2471,
+ 2472,
+ 2473,
+ 2474,
+ 2475,
+ 2476,
+ 2477,
+ 2478,
+ 2479,
+ 2480,
+ 2481,
+ 2482,
+ 2483,
+ 2484,
+ 2485,
+ 2486,
+ 2487,
+ 2488,
+ 2489,
+ 2490,
+ 2491,
+ 2492,
+ 2493,
+ 2494,
+ 2495,
+ 2496,
+ 2497,
+ 2498,
+ 2499,
+ 2500,
+ 2501,
+ 2502,
+ 2503,
+ 2504,
+ 2505,
+ 2506,
+ 2507,
+ 2508,
+ 2509,
+ 2510,
+ 2511,
+ 2512,
+ 2513,
+ 2514,
+ 2515,
+ 2516,
+ 2517,
+ 2518,
+ 2519,
+ 2520,
+ 2521,
+ 2522,
+ 2523,
+ 2524,
+ 2525,
+ 2526,
+ 2527,
+ 2528,
+ 2529,
+ 2530,
+ 2531,
+ 2532,
+ 2533,
+ 2534,
+ 2535,
+ 2536,
+ 2537,
+ 2538,
+ 2539,
+ 2540,
+ 2541,
+ 2542,
+ 2543,
+ 2544,
+ 2545,
+ 2546,
+ 2547,
+ 2548,
+ 2549,
+ 2550,
+ 2551,
+ 2552,
+ 2553,
+ 2554,
+ 2555,
+ 2556,
+ 2557,
+ 2558,
+ 2559,
+ 2560,
+ 2561,
+ 2562,
+ 2563,
+ 2564,
+ 2565,
+ 2566,
+ 2567,
+ 2568,
+ 2569,
+ 2570,
+ 2571,
+ 2572,
+ 2573,
+ 2574,
+ 2575,
+ 2576,
+ 2577,
+ 2578,
+ 2579,
+ 2580,
+ 2581,
+ 2582,
+ 2583,
+ 2584,
+ 2585,
+ 2586,
+ 2587,
+ 2588,
+ 2589,
+ 2590,
+ 2591,
+ 2592,
+ 2593,
+ 2594,
+ 2595,
+ 2596,
+ 2597,
+ 2598,
+ 2599,
+ 2600,
+ 2601,
+ 2602,
+ 2603,
+ 2604,
+ 2605,
+ 2606,
+ 2607,
+ 2608,
+ 2609,
+ 2610,
+ 2611,
+ 2612,
+ 2613,
+ 2614,
+ 2615,
+ 2616,
+ 2617,
+ 2618,
+ 2619,
+ 2620,
+ 2621,
+ 2622,
+ 2623,
+ 2624,
+ 2625,
+ 2626,
+ 2627,
+ 2628,
+ 2629,
+ 2630,
+ 2631,
+ 2632,
+ 2633,
+ 2634,
+ 2635,
+ 2636,
+ 2637,
+ 2638,
+ 2639,
+ 2640,
+ 2641,
+ 2642,
+ 2643,
+ 2644,
+ 2645,
+ 2646,
+ 2647,
+ 2648,
+ 2649,
+ 2650,
+ 2651,
+ 2652,
+ 2653,
+ 2654,
+ 2655,
+ 2656,
+ 2657,
+ 2658,
+ 2659,
+ 2660,
+ 2661,
+ 2662,
+ 2663,
+ 2664,
+ 2665,
+ 2666,
+ 2667,
+ 2668,
+ 2669,
+ 2670,
+ 2671,
+ 2672,
+ 2673,
+ 2674,
+ 2675,
+ 2676,
+ 2677,
+ 2678,
+ 2679,
+ 2680,
+ 2681,
+ 2682,
+ 2683,
+ 2684,
+ 2685,
+ 2686,
+ 2687,
+ 2688,
+ 2689,
+ 2690,
+ 2691,
+ 2692,
+ 2693,
+ 2694,
+ 2695,
+ 2696,
+ 2697,
+ 2698,
+ 2699,
+ 2700,
+ 2701,
+ 2702,
+ 2703,
+ 2704,
+ 2705,
+ 2706,
+ 2707,
+ 2708,
+ 2709,
+ 2710,
+ 2711,
+ 2712,
+ 2713,
+ 2714,
+ 2715,
+ 2716,
+ 2717,
+ 2718,
+ 2719,
+ 2720,
+ 2721,
+ 2722,
+ 2723,
+ 2724,
+ 2725,
+ 2726,
+ 2727,
+ 2728,
+ 2729,
+ 2730,
+ 2731,
+ 2732,
+ 2733,
+ 2734,
+ 2735,
+ 2736,
+ 2737,
+ 2738,
+ 2739,
+ 2740,
+ 2741,
+ 2742,
+ 2743,
+ 2744,
+ 2745,
+ 2746,
+ 2747,
+ 2748,
+ 2749,
+ 2750,
+ 2751,
+ 2752,
+ 2753,
+ 2754,
+ 2755,
+ 2756,
+ 2757,
+ 2758,
+ 2759,
+ 2760,
+ 2761,
+ 2762,
+ 2763,
+ 2764,
+ 2765,
+ 2766,
+ 2767,
+ 2768,
+ 2769,
+ 2770,
+ 2771,
+ 2772,
+ 2773,
+ 2774,
+ 2775,
+ 2776,
+ 2777,
+ 2778,
+ 2779,
+ 2780,
+ 2781,
+ 2782,
+ 2783,
+ 2784,
+ 2785,
+ 2786,
+ 2787,
+ 2788,
+ 2789,
+ 2790,
+ 2791,
+ 2792,
+ 2793,
+ 2794,
+ 2795,
+ 2796,
+ 2797,
+ 2798,
+ 2799,
+ 2800,
+ 2801,
+ 2802,
+ 2803,
+ 2804,
+ 2805,
+ 2806,
+ 2807,
+ 2808,
+ 2809,
+ 2810,
+ 2811,
+ 2812,
+ 2813,
+ 2814,
+ 2815,
+ 2816,
+ 2817,
+ 2818,
+ 2819,
+ 2820,
+ 2821,
+ 2822,
+ 2823,
+ 2824,
+ 2825,
+ 2826,
+ 2827,
+ 2828,
+ 2829,
+ 2830,
+ 2831,
+ 2832,
+ 2833,
+ 2834,
+ 2835,
+ 2836,
+ 2837,
+ 2838,
+ 2839,
+ 2840,
+ 2841,
+ 2842,
+ 2843,
+ 2844,
+ 2845,
+ 2846,
+ 2847,
+ 2848,
+ 2849,
+ 2850,
+ 2851,
+ 2852,
+ 2853,
+ 2854,
+ 2855,
+ 2856,
+ 2857,
+ 2858,
+ 2859,
+ 2860,
+ 2861,
+ 2862,
+ 2863,
+ 2864,
+ 2865,
+ 2866,
+ 2867,
+ 2868,
+ 2869,
+ 2870,
+ 2871,
+ 2872,
+ 2873,
+ 2874,
+ 2875,
+ 2876,
+ 2877,
+ 2878,
+ 2879,
+ 2880,
+ 2881,
+ 2882,
+ 2883,
+ 2884,
+ 2885,
+ 2886,
+ 2887,
+ 2888,
+ 2889,
+ 2890,
+ 2891,
+ 2892,
+ 2893,
+ 2894,
+ 2895,
+ 2896,
+ 2897,
+ 2898,
+ 2899,
+ 2900,
+ 2901,
+ 2902,
+ 2903,
+ 2904,
+ 2905,
+ 2906,
+ 2907,
+ 2908,
+ 2909,
+ 2910,
+ 2911,
+ 2912,
+ 2913,
+ 2914,
+ 2915,
+ 2916,
+ 2917,
+ 2918,
+ 2919,
+ 2920,
+ 2921,
+ 2922,
+ 2923,
+ 2924,
+ 2925,
+ 2926,
+ 2927,
+ 2928,
+ 2929,
+ 2930,
+ 2931,
+ 2932,
+ 2933,
+ 2934,
+ 2935,
+ 2936,
+ 2937,
+ 2938,
+ 2939,
+ 2940,
+ 2941,
+ 2942,
+ 2943,
+ 2944,
+ 2945,
+ 2946,
+ 2947,
+ 2948,
+ 2949,
+ 2950,
+ 2951,
+ 2952,
+ 2953,
+ 2954,
+ 2955,
+ 2956,
+ 2957,
+ 2958,
+ 2959,
+ 2960,
+ 2961,
+ 2962,
+ 2963,
+ 2964,
+ 2965,
+ 2966,
+ 2967,
+ 2968,
+ 2969,
+ 2970,
+ 2971,
+ 2972,
+ 2973,
+ 2974,
+ 2975,
+ 2976,
+ 2977,
+ 2978,
+ 2979,
+ 2980,
+ 2981,
+ 2982,
+ 2983,
+ 2984,
+ 2985,
+ 2986,
+ 2987,
+ 2988,
+ 2989,
+ 2990,
+ 2991,
+ 2992,
+ 2993,
+ 2994,
+ 2995,
+ 2996,
+ 2997,
+ 2998,
+ 2999,
+ 3000,
+ 3001,
+ 3002,
+ 3003,
+ 3004,
+ 3005,
+ 3006,
+ 3007,
+ 3008,
+ 3009,
+ 3010,
+ 3011,
+ 3012,
+ 3013,
+ 3014,
+ 3015,
+ 3016,
+ 3017,
+ 3018,
+ 3019,
+ 3020,
+ 3021,
+ 3022,
+ 3023,
+ 3024,
+ 3025,
+ 3026,
+ 3027,
+ 3028,
+ 3029,
+ 3030,
+ 3031,
+ 3032,
+ 3033,
+ 3034,
+ 3035,
+ 3036,
+ 3037,
+ 3038,
+ 3039,
+ 3040,
+ 3041,
+ 3042,
+ 3043,
+ 3044,
+ 3045,
+ 3046,
+ 3047,
+ 3048,
+ 3049,
+ 3050,
+ 3051,
+ 3052,
+ 3053,
+ 3054,
+ 3055,
+ 3056,
+ 3057,
+ 3058,
+ 3059,
+ 3060,
+ 3061,
+ 3062,
+ 3063,
+ 3064,
+ 3065,
+ 3066,
+ 3067,
+ 3068,
+ 3069,
+ 3070,
+ 3071,
+ 3072,
+ 3073,
+ 3074,
+ 3075,
+ 3076,
+ 3077,
+ 3078,
+ 3079,
+ 3080,
+ 3081,
+ 3082,
+ 3083,
+ 3084,
+ 3085,
+ 3086,
+ 3087,
+ 3088,
+ 3089,
+ 3090,
+ 3091,
+ 3092,
+ 3093,
+ 3094,
+ 3095,
+ 3096,
+ 3097,
+ 3098,
+ 3099,
+ 3100,
+ 3101,
+ 3102,
+ 3103,
+ 3104,
+ 3105,
+ 3106,
+ 3107,
+ 3108,
+ 3109,
+ 3110,
+ 3111,
+ 3112,
+ 3113,
+ 3114,
+ 3115,
+ 3116,
+ 3117,
+ 3118,
+ 3119,
+ 3120,
+ 3121,
+ 3122,
+ 3123,
+ 3124,
+ 3125,
+ 3126,
+ 3127,
+ 3128,
+ 3129,
+ 3130,
+ 3131,
+ 3132,
+ 3133,
+ 3134,
+ 3135,
+ 3136,
+ 3137,
+ 3138,
+ 3139,
+ 3140,
+ 3141,
+ 3142,
+ 3143,
+ 3144,
+ 3145,
+ 3146,
+ 3147,
+ 3148,
+ 3149,
+ 3150,
+ 3151,
+ 3152,
+ 3153,
+ 3154,
+ 3155,
+ 3156,
+ 3157,
+ 3158,
+ 3159,
+ 3160,
+ 3161,
+ 3162,
+ 3163,
+ 3164,
+ 3165,
+ 3166,
+ 3167,
+ 3168,
+ 3169,
+ 3170,
+ 3171,
+ 3172,
+ 3173,
+ 3174,
+ 3175,
+ 3176,
+ 3177,
+ 3178,
+ 3179,
+ 3180,
+ 3181,
+ 3182,
+ 3183,
+ 3184,
+ 3185,
+ 3186,
+ 3187,
+ 3188,
+ 3189,
+ 3190,
+ 3191,
+ 3192,
+ 3193,
+ 3194,
+ 3195,
+ 3196,
+ 3197,
+ 3198,
+ 3199,
+ 3200,
+ 3201,
+ 3202,
+ 3203,
+ 3204,
+ 3205,
+ 3206,
+ 3207,
+ 3208,
+ 3209,
+ 3210,
+ 3211,
+ 3212,
+ 3213,
+ 3214,
+ 3215,
+ 3216,
+ 3217,
+ 3218,
+ 3219,
+ 3220,
+ 3221,
+ 3222,
+ 3223,
+ 3224,
+ 3225,
+ 3226,
+ 3227,
+ 3228,
+ 3229,
+ 3230,
+ 3231,
+ 3232,
+ 3233,
+ 3234,
+ 3235,
+ 3236,
+ 3237,
+ 3238,
+ 3239,
+ 3240,
+ 3241,
+ 3242,
+ 3243,
+ 3244,
+ 3245,
+ 3246,
+ 3247,
+ 3248,
+ 3249,
+ 3250,
+ 3251,
+ 3252,
+ 3253,
+ 3254,
+ 3255,
+ 3256,
+ 3257,
+ 3258,
+ 3259,
+ 3260,
+ 3261,
+ 3262,
+ 3263,
+ 3264,
+ 3265,
+ 3266,
+ 3267,
+ 3268,
+ 3269,
+ 3270,
+ 3271,
+ 3272,
+ 3273,
+ 3274,
+ 3275,
+ 3276,
+ 3277,
+ 3278,
+ 3279,
+ 3280,
+ 3281,
+ 3282,
+ 3283,
+ 3284,
+ 3285,
+ 3286,
+ 3287,
+ 3288,
+ 3289,
+ 3290,
+ 3291,
+ 3292,
+ 3293,
+ 3294,
+ 3295,
+ 3296,
+ 3297,
+ 3298,
+ 3299,
+ 3300,
+ 3301,
+ 3302,
+ 3303,
+ 3304,
+ 3305,
+ 3306,
+ 3307,
+ 3308,
+ 3309,
+ 3310,
+ 3311,
+ 3312,
+ 3313,
+ 3314,
+ 3315,
+ 3316,
+ 3317,
+ 3318,
+ 3319,
+ 3320,
+ 3321,
+ 3322,
+ 3323,
+ 3324,
+ 3325,
+ 3326,
+ 3327,
+ 3328,
+ 3329,
+ 3330,
+ 3331,
+ 3332,
+ 3333,
+ 3334,
+ 3335,
+ 3336,
+ 3337,
+ 3338,
+ 3339,
+ 3340,
+ 3341,
+ 3342,
+ 3343,
+ 3344,
+ 3345,
+ 3346,
+ 3347,
+ 3348,
+ 3349,
+ 3350,
+ 3351,
+ 3352,
+ 3353,
+ 3354,
+ 3355,
+ 3356,
+ 3357,
+ 3358,
+ 3359,
+ 3360,
+ 3361,
+ 3362,
+ 3363,
+ 3364,
+ 3365,
+ 3366,
+ 3367,
+ 3368,
+ 3369,
+ 3370,
+ 3371,
+ 3372,
+ 3373,
+ 3374,
+ 3375,
+ 3376,
+ 3377,
+ 3378,
+ 3379,
+ 3380,
+ 3381,
+ 3382,
+ 3383,
+ 3384,
+ 3385,
+ 3386,
+ 3387,
+ 3388,
+ 3389,
+ 3390,
+ 3391,
+ 3392,
+ 3393,
+ 3394,
+ 3395,
+ 3396,
+ 3397,
+ 3398,
+ 3399,
+ 3400,
+ 3401,
+ 3402,
+ 3403,
+ 3404,
+ 3405,
+ 3406,
+ 3407,
+ 3408,
+ 3409,
+ 3410,
+ 3411,
+ 3412,
+ 3413,
+ 3414,
+ 3415,
+ 3416,
+ 3417,
+ 3418,
+ 3419,
+ 3420,
+ 3421,
+ 3422,
+ 3423,
+ 3424,
+ 3425,
+ 3426,
+ 3427,
+ 3428,
+ 3429,
+ 3430,
+ 3431,
+ 3432,
+ 3433,
+ 3434,
+ 3435,
+ 3436,
+ 3437,
+ 3438,
+ 3439,
+ 3440,
+ 3441,
+ 3442,
+ 3443,
+ 3444,
+ 3445,
+ 3446,
+ 3447,
+ 3448,
+ 3449,
+ 3450,
+ 3451,
+ 3452,
+ 3453,
+ 3454,
+ 3455,
+ 3456,
+ 3457,
+ 3458,
+ 3459,
+ 3460,
+ 3461,
+ 3462,
+ 3463,
+ 3464,
+ 3465,
+ 3466,
+ 3467,
+ 3468,
+ 3469,
+ 3470,
+ 3471,
+ 3472,
+ 3473,
+ 3474,
+ 3475,
+ 3476,
+ 3477,
+ 3478,
+ 3479,
+ 3480,
+ 3481,
+ 3482,
+ 3483,
+ 3484,
+ 3485,
+ 3486,
+ 3487,
+ 3488,
+ 3489,
+ 3490,
+ 3491,
+ 3492,
+ 3493,
+ 3494,
+ 3495,
+ 3496,
+ 3497,
+ 3498,
+ 3499,
+ 3500,
+ 3501,
+ 3502,
+ 3503,
+ 3504,
+ 3505,
+ 3506,
+ 3507,
+ 3508,
+ 3509,
+ 3510,
+ 3511,
+ 3512,
+ 3513,
+ 3514,
+ 3515,
+ 3516,
+ 3517,
+ 3518,
+ 3519,
+ 3520,
+ 3521,
+ 3522,
+ 3523,
+ 3524,
+ 3525,
+ 3526,
+ 3527,
+ 3528,
+ 3529,
+ 3530,
+ 3531,
+ 3532,
+ 3533,
+ 3534,
+ 3535,
+ 3536,
+ 3537,
+ 3538,
+ 3539,
+ 3540,
+ 3541,
+ 3542,
+ 3543,
+ 3544,
+ 3545,
+ 3546,
+ 3547,
+ 3548,
+ 3549,
+ 3550,
+ 3551,
+ 3552,
+ 3553,
+ 3554,
+ 3555,
+ 3556,
+ 3557,
+ 3558,
+ 3559,
+ 3560,
+ 3561,
+ 3562,
+ 3563,
+ 3564,
+ 3565,
+ 3566,
+ 3567,
+ 3568,
+ 3569,
+ 3570,
+ 3571,
+ 3572,
+ 3573,
+ 3574,
+ 3575,
+ 3576,
+ 3577,
+ 3578,
+ 3579,
+ 3580,
+ 3581,
+ 3582,
+ 3583,
+ 3584,
+ 3585,
+ 3586,
+ 3587,
+ 3588,
+ 3589,
+ 3590,
+ 3591,
+ 3592,
+ 3593,
+ 3594,
+ 3595,
+ 3596,
+ 3597,
+ 3598,
+ 3599,
+ 3600,
+ 3601,
+ 3602,
+ 3603,
+ 3604,
+ 3605,
+ 3606,
+ 3607,
+ 3608,
+ 3609,
+ 3610,
+ 3611,
+ 3612,
+ 3613,
+ 3614,
+ 3615,
+ 3616,
+ 3617,
+ 3618,
+ 3619,
+ 3620,
+ 3621,
+ 3622,
+ 3623,
+ 3624,
+ 3625,
+ 3626,
+ 3627,
+ 3628,
+ 3629,
+ 3630,
+ 3631,
+ 3632,
+ 3633,
+ 3634,
+ 3635,
+ 3636,
+ 3637,
+ 3638,
+ 3639,
+ 3640,
+ 3641,
+ 3642,
+ 3643,
+ 3644,
+ 3645,
+ 3646,
+ 3647,
+ 3648,
+ 3649,
+ 3650,
+ 3651,
+ 3652,
+ 3653,
+ 3654,
+ 3655,
+ 3656,
+ 3657,
+ 3658,
+ 3659,
+ 3660,
+ 3661,
+ 3662,
+ 3663,
+ 3664,
+ 3665,
+ 3666,
+ 3667,
+ 3668,
+ 3669,
+ 3670,
+ 3671,
+ 3672,
+ 3673,
+ 3674,
+ 3675,
+ 3676,
+ 3677,
+ 3678,
+ 3679,
+ 3680,
+ 3681,
+ 3682,
+ 3683,
+ 3684,
+ 3685,
+ 3686,
+ 3687,
+ 3688,
+ 3689,
+ 3690,
+ 3691,
+ 3692,
+ 3693,
+ 3694,
+ 3695,
+ 3696,
+ 3697,
+ 3698,
+ 3699,
+ 3700,
+ 3701,
+ 3702,
+ 3703,
+ 3704,
+ 3705,
+ 3706,
+ 3707,
+ 3708,
+ 3709,
+ 3710,
+ 3711,
+ 3712,
+ 3713,
+ 3714,
+ 3715,
+ 3716,
+ 3717,
+ 3718,
+ 3719,
+ 3720,
+ 3721,
+ 3722,
+ 3723,
+ 3724,
+ 3725,
+ 3726,
+ 3727,
+ 3728,
+ 3729,
+ 3730,
+ 3731,
+ 3732,
+ 3733,
+ 3734,
+ 3735,
+ 3736,
+ 3737,
+ 3738,
+ 3739,
+ 3740,
+ 3741,
+ 3742,
+ 3743,
+ 3744,
+ 3745,
+ 3746,
+ 3747,
+ 3748,
+ 3749,
+ 3750,
+ 3751,
+ 3752,
+ 3753,
+ 3754,
+ 3755,
+ 3756,
+ 3757,
+ 3758,
+ 3759,
+ 3760,
+ 3761,
+ 3762,
+ 3763,
+ 3764,
+ 3765,
+ 3766,
+ 3767,
+ 3768,
+ 3769,
+ 3770,
+ 3771,
+ 3772,
+ 3773,
+ 3774,
+ 3775,
+ 3776,
+ 3777,
+ 3778,
+ 3779,
+ 3780,
+ 3781,
+ 3782,
+ 3783,
+ 3784,
+ 3785,
+ 3786,
+ 3787,
+ 3788,
+ 3789,
+ 3790,
+ 3791,
+ 3792,
+ 3793,
+ 3794,
+ 3795,
+ 3796,
+ 3797,
+ 3798,
+ 3799,
+ 3800,
+ 3801,
+ 3802,
+ 3803,
+ 3804,
+ 3805,
+ 3806,
+ 3807,
+ 3808,
+ 3809,
+ 3810,
+ 3811,
+ 3812,
+ 3813,
+ 3814,
+ 3815,
+ 3816,
+ 3817,
+ 3818,
+ 3819,
+ 3820,
+ 3821,
+ 3822,
+ 3823,
+ 3824,
+ 3825,
+ 3826,
+ 3827,
+ 3828,
+ 3829,
+ 3830,
+ 3831,
+ 3832,
+ 3833,
+ 3834,
+ 3835,
+ 3836,
+ 3837,
+ 3838,
+ 3839,
+ 3840,
+ 3841,
+ 3842,
+ 3843,
+ 3844,
+ 3845,
+ 3846,
+ 3847,
+ 3848,
+ 3849,
+ 3850,
+ 3851,
+ 3852,
+ 3853,
+ 3854,
+ 3855,
+ 3856,
+ 3857,
+ 3858,
+ 3859,
+ 3860,
+ 3861,
+ 3862,
+ 3863,
+ 3864,
+ 3865,
+ 3866,
+ 3867,
+ 3868,
+ 3869,
+ 3870,
+ 3871,
+ 3872,
+ 3873,
+ 3874,
+ 3875,
+ 3876,
+ 3877,
+ 3878,
+ 3879,
+ 3880,
+ 3881,
+ 3882,
+ 3883,
+ 3884,
+ 3885,
+ 3886,
+ 3887,
+ 3888
+ ],
+ "right_inds": [
+ 33,
+ 34,
+ 35,
+ 36,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 121,
+ 122,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 128,
+ 129,
+ 130,
+ 131,
+ 132,
+ 133,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 139,
+ 140,
+ 141,
+ 142,
+ 143,
+ 144,
+ 145,
+ 146,
+ 147,
+ 148,
+ 149,
+ 150,
+ 151,
+ 152,
+ 153,
+ 154,
+ 155,
+ 156,
+ 157,
+ 158,
+ 159,
+ 160,
+ 161,
+ 162,
+ 164,
+ 165,
+ 166,
+ 167,
+ 168,
+ 169,
+ 170,
+ 171,
+ 172,
+ 173,
+ 174,
+ 175,
+ 176,
+ 177,
+ 178,
+ 179,
+ 180,
+ 181,
+ 182,
+ 183,
+ 184,
+ 185,
+ 186,
+ 187,
+ 188,
+ 189,
+ 190,
+ 191,
+ 192,
+ 193,
+ 194,
+ 195,
+ 196,
+ 197,
+ 198,
+ 199,
+ 200,
+ 201,
+ 202,
+ 203,
+ 204,
+ 205,
+ 206,
+ 207,
+ 208,
+ 212,
+ 214,
+ 215,
+ 217,
+ 218,
+ 219,
+ 220,
+ 221,
+ 222,
+ 223,
+ 224,
+ 225,
+ 226,
+ 228,
+ 229,
+ 230,
+ 231,
+ 232,
+ 233,
+ 234,
+ 235,
+ 236,
+ 237,
+ 238,
+ 239,
+ 240,
+ 241,
+ 242,
+ 243,
+ 244,
+ 245,
+ 246,
+ 247,
+ 248,
+ 249,
+ 250,
+ 251,
+ 252,
+ 253,
+ 254,
+ 255,
+ 256,
+ 257,
+ 258,
+ 259,
+ 260,
+ 261,
+ 262,
+ 263,
+ 264,
+ 265,
+ 266,
+ 267,
+ 268,
+ 269,
+ 270,
+ 271,
+ 272,
+ 273,
+ 274,
+ 275,
+ 276,
+ 277,
+ 278,
+ 279,
+ 280,
+ 281,
+ 282,
+ 283,
+ 284,
+ 285,
+ 286,
+ 287,
+ 288,
+ 289,
+ 290,
+ 291,
+ 292,
+ 293,
+ 294,
+ 295,
+ 296,
+ 297,
+ 298,
+ 299,
+ 300,
+ 301,
+ 302,
+ 303,
+ 304,
+ 305,
+ 306,
+ 307,
+ 308,
+ 309,
+ 310,
+ 311,
+ 312,
+ 313,
+ 314,
+ 315,
+ 316,
+ 317,
+ 318,
+ 319,
+ 320,
+ 321,
+ 322,
+ 323,
+ 324,
+ 325,
+ 327,
+ 328,
+ 329,
+ 330,
+ 331,
+ 332,
+ 333,
+ 334,
+ 335,
+ 336,
+ 337,
+ 338,
+ 339,
+ 340,
+ 341,
+ 342,
+ 343,
+ 344,
+ 345,
+ 346,
+ 347,
+ 348,
+ 349,
+ 350,
+ 351,
+ 352,
+ 353,
+ 354,
+ 355,
+ 356,
+ 357,
+ 358,
+ 359,
+ 360,
+ 361,
+ 362,
+ 363,
+ 364,
+ 365,
+ 366,
+ 367,
+ 368,
+ 369,
+ 370,
+ 371,
+ 372,
+ 373,
+ 374,
+ 375,
+ 376,
+ 377,
+ 378,
+ 379,
+ 380,
+ 381,
+ 382,
+ 383,
+ 384,
+ 385,
+ 386,
+ 387,
+ 388,
+ 389,
+ 390,
+ 391,
+ 392,
+ 393,
+ 394,
+ 396,
+ 397,
+ 398,
+ 399,
+ 400,
+ 401,
+ 402,
+ 403,
+ 404,
+ 405,
+ 406,
+ 407,
+ 408,
+ 409,
+ 410,
+ 411,
+ 412,
+ 413,
+ 414,
+ 415,
+ 416,
+ 417,
+ 418,
+ 419,
+ 420,
+ 421,
+ 422,
+ 423,
+ 424,
+ 425,
+ 426,
+ 427,
+ 428,
+ 429,
+ 430,
+ 431,
+ 432,
+ 433,
+ 434,
+ 435,
+ 436,
+ 437,
+ 438,
+ 439,
+ 440,
+ 441,
+ 442,
+ 443,
+ 444,
+ 445,
+ 446,
+ 447,
+ 448,
+ 449,
+ 450,
+ 451,
+ 453,
+ 454,
+ 455,
+ 456,
+ 457,
+ 458,
+ 459,
+ 460,
+ 461,
+ 462,
+ 463,
+ 464,
+ 465,
+ 466,
+ 467,
+ 468,
+ 469,
+ 470,
+ 471,
+ 472,
+ 473,
+ 474,
+ 475,
+ 476,
+ 477,
+ 478,
+ 479,
+ 480,
+ 481,
+ 482,
+ 483,
+ 484,
+ 485,
+ 486,
+ 487,
+ 488,
+ 489,
+ 490,
+ 491,
+ 492,
+ 493,
+ 494,
+ 495,
+ 496,
+ 497,
+ 498,
+ 499,
+ 500,
+ 501,
+ 502,
+ 503,
+ 504,
+ 505,
+ 506,
+ 507,
+ 508,
+ 509,
+ 510,
+ 511,
+ 512,
+ 513,
+ 514,
+ 515,
+ 516,
+ 517,
+ 518,
+ 519,
+ 520,
+ 521,
+ 522,
+ 523,
+ 524,
+ 525,
+ 526,
+ 527,
+ 528,
+ 529,
+ 530,
+ 531,
+ 532,
+ 533,
+ 534,
+ 535,
+ 536,
+ 537,
+ 538,
+ 539,
+ 540,
+ 541,
+ 542,
+ 543,
+ 544,
+ 545,
+ 546,
+ 547,
+ 548,
+ 549,
+ 550,
+ 551,
+ 552,
+ 553,
+ 554,
+ 555,
+ 556,
+ 557,
+ 558,
+ 559,
+ 560,
+ 561,
+ 562,
+ 563,
+ 564,
+ 565,
+ 566,
+ 567,
+ 568,
+ 569,
+ 570,
+ 571,
+ 572,
+ 573,
+ 574,
+ 575,
+ 576,
+ 577,
+ 579,
+ 580,
+ 581,
+ 582,
+ 583,
+ 584,
+ 585,
+ 586,
+ 587,
+ 588,
+ 589,
+ 590,
+ 591,
+ 592,
+ 593,
+ 594,
+ 595,
+ 596,
+ 597,
+ 598,
+ 599,
+ 600,
+ 601,
+ 602,
+ 603,
+ 604,
+ 605,
+ 606,
+ 607,
+ 608,
+ 609,
+ 610,
+ 611,
+ 612,
+ 613,
+ 614,
+ 615,
+ 616,
+ 617,
+ 618,
+ 619,
+ 620,
+ 621,
+ 622,
+ 623,
+ 624,
+ 625,
+ 626,
+ 627,
+ 628,
+ 629,
+ 630,
+ 631,
+ 632,
+ 633,
+ 634,
+ 635,
+ 636,
+ 637,
+ 638,
+ 639,
+ 640,
+ 641,
+ 642,
+ 643,
+ 644,
+ 645,
+ 646,
+ 647,
+ 648,
+ 649,
+ 650,
+ 651,
+ 652,
+ 653,
+ 654,
+ 655,
+ 656,
+ 657,
+ 658,
+ 659,
+ 660,
+ 661,
+ 662,
+ 663,
+ 664,
+ 665,
+ 666,
+ 667,
+ 668,
+ 669,
+ 670,
+ 671,
+ 672,
+ 673,
+ 674,
+ 675,
+ 676,
+ 677,
+ 678,
+ 679,
+ 680,
+ 681,
+ 682,
+ 683,
+ 684,
+ 685,
+ 686,
+ 687,
+ 688,
+ 689,
+ 690,
+ 691,
+ 692,
+ 693,
+ 694,
+ 695,
+ 696,
+ 697,
+ 698,
+ 699,
+ 700,
+ 701,
+ 702,
+ 703,
+ 704,
+ 705,
+ 706,
+ 707,
+ 708,
+ 709,
+ 710,
+ 711,
+ 712,
+ 713,
+ 714,
+ 715,
+ 716,
+ 717,
+ 718,
+ 719,
+ 720,
+ 721,
+ 722,
+ 723,
+ 724,
+ 725,
+ 726,
+ 727,
+ 728,
+ 729,
+ 730,
+ 731,
+ 732,
+ 733,
+ 734,
+ 735,
+ 736,
+ 737,
+ 738,
+ 739,
+ 740,
+ 741,
+ 742,
+ 743,
+ 744,
+ 745,
+ 746,
+ 747,
+ 748,
+ 749,
+ 750,
+ 751,
+ 752,
+ 753,
+ 754,
+ 755,
+ 756,
+ 757,
+ 758,
+ 759,
+ 760,
+ 761,
+ 762,
+ 763,
+ 764,
+ 765,
+ 766,
+ 767,
+ 768,
+ 769,
+ 770,
+ 771,
+ 772,
+ 773,
+ 774,
+ 775,
+ 776,
+ 777,
+ 778,
+ 779,
+ 780,
+ 781,
+ 782,
+ 783,
+ 784,
+ 785,
+ 786,
+ 787,
+ 788,
+ 789,
+ 790,
+ 791,
+ 792,
+ 793,
+ 794,
+ 795,
+ 796,
+ 797,
+ 798,
+ 799,
+ 800,
+ 801,
+ 802,
+ 803,
+ 804,
+ 805,
+ 806,
+ 807,
+ 808,
+ 809,
+ 810,
+ 811,
+ 812,
+ 813,
+ 814,
+ 815,
+ 816,
+ 817,
+ 818,
+ 819,
+ 820,
+ 821,
+ 822,
+ 823,
+ 824,
+ 825,
+ 826,
+ 827,
+ 828,
+ 829,
+ 830,
+ 831,
+ 832,
+ 833,
+ 834,
+ 835,
+ 836,
+ 837,
+ 838,
+ 839,
+ 840,
+ 841,
+ 842,
+ 843,
+ 844,
+ 845,
+ 846,
+ 847,
+ 848,
+ 849,
+ 850,
+ 851,
+ 852,
+ 853,
+ 854,
+ 855,
+ 856,
+ 857,
+ 858,
+ 859,
+ 860,
+ 861,
+ 862,
+ 863,
+ 864,
+ 865,
+ 866,
+ 867,
+ 868,
+ 869,
+ 870,
+ 871,
+ 872,
+ 873,
+ 874,
+ 875,
+ 876,
+ 877,
+ 878,
+ 879,
+ 880,
+ 881,
+ 882,
+ 883,
+ 884,
+ 885,
+ 886,
+ 887,
+ 888,
+ 889,
+ 890,
+ 891,
+ 892,
+ 893,
+ 894,
+ 895,
+ 896,
+ 897,
+ 898,
+ 899,
+ 900,
+ 901,
+ 902,
+ 903,
+ 904,
+ 905,
+ 906,
+ 907,
+ 908,
+ 909,
+ 911,
+ 912,
+ 913,
+ 914,
+ 915,
+ 916,
+ 917,
+ 918,
+ 919,
+ 920,
+ 921,
+ 922,
+ 923,
+ 924,
+ 925,
+ 926,
+ 927,
+ 928,
+ 929,
+ 930,
+ 931,
+ 932,
+ 933,
+ 934,
+ 935,
+ 936,
+ 937,
+ 938,
+ 939,
+ 940,
+ 941,
+ 942,
+ 943,
+ 944,
+ 945,
+ 946,
+ 947,
+ 948,
+ 949,
+ 950,
+ 951,
+ 952,
+ 953,
+ 954,
+ 955,
+ 956,
+ 957,
+ 958,
+ 960,
+ 961,
+ 962,
+ 963,
+ 965,
+ 966,
+ 967,
+ 968,
+ 969,
+ 970,
+ 971,
+ 972,
+ 973,
+ 974,
+ 978,
+ 979,
+ 980,
+ 981,
+ 982,
+ 983,
+ 984,
+ 985,
+ 986,
+ 987,
+ 988,
+ 989,
+ 990,
+ 991,
+ 992,
+ 993,
+ 994,
+ 995,
+ 996,
+ 997,
+ 998,
+ 999,
+ 1000,
+ 1001,
+ 1002,
+ 1003,
+ 1004,
+ 1005,
+ 1006,
+ 1007,
+ 1008,
+ 1009,
+ 1010,
+ 1011,
+ 1012,
+ 1013,
+ 1014,
+ 1015,
+ 1016,
+ 1017,
+ 1018,
+ 1019,
+ 1020,
+ 1021,
+ 1022,
+ 1023,
+ 1024,
+ 1025,
+ 1026,
+ 1027,
+ 1028,
+ 1029,
+ 1030,
+ 1031,
+ 1032,
+ 1033,
+ 1034,
+ 1035,
+ 1036,
+ 1037,
+ 1038,
+ 1039,
+ 1040,
+ 1041,
+ 1042,
+ 1043,
+ 1044,
+ 1045,
+ 1046,
+ 1047,
+ 1048,
+ 1049,
+ 1050,
+ 1051,
+ 1052,
+ 1053,
+ 1054,
+ 1055,
+ 1056,
+ 1057,
+ 1058,
+ 1059,
+ 1060,
+ 1061,
+ 1062,
+ 1063,
+ 1064,
+ 1065,
+ 1066,
+ 1067,
+ 1068,
+ 1069,
+ 1070,
+ 1071,
+ 1072,
+ 1073,
+ 1074,
+ 1075,
+ 1076,
+ 1077,
+ 1078,
+ 1079,
+ 1080,
+ 1081,
+ 1082,
+ 1083,
+ 1084,
+ 1085,
+ 1086,
+ 1087,
+ 1088,
+ 1089,
+ 1090,
+ 1091,
+ 1092,
+ 1093,
+ 1094,
+ 1095,
+ 1096,
+ 1097,
+ 1098,
+ 1099,
+ 1100,
+ 1101,
+ 1102,
+ 1103,
+ 1104,
+ 1105,
+ 1106,
+ 1107,
+ 1108,
+ 1109,
+ 1110,
+ 1111,
+ 1112,
+ 1113,
+ 1114,
+ 1115,
+ 1116,
+ 1117,
+ 1118,
+ 1119,
+ 1120,
+ 1121,
+ 1122,
+ 1123,
+ 1124,
+ 1125,
+ 1126,
+ 1127,
+ 1128,
+ 1129,
+ 1130,
+ 1131,
+ 1132,
+ 1133,
+ 1134,
+ 1135,
+ 1136,
+ 1137,
+ 1138,
+ 1139,
+ 1140,
+ 1141,
+ 1142,
+ 1143,
+ 1144,
+ 1145,
+ 1146,
+ 1147,
+ 1148,
+ 1149,
+ 1150,
+ 1151,
+ 1152,
+ 1153,
+ 1154,
+ 1155,
+ 1156,
+ 1157,
+ 1158,
+ 1159,
+ 1160,
+ 1161,
+ 1162,
+ 1163,
+ 1164,
+ 1165,
+ 1166,
+ 1167,
+ 1168,
+ 1169,
+ 1170,
+ 1171,
+ 1173,
+ 1174,
+ 1177,
+ 1179,
+ 1180,
+ 1181,
+ 1182,
+ 1183,
+ 1184,
+ 1185,
+ 1186,
+ 1187,
+ 1188,
+ 1189,
+ 1190,
+ 1191,
+ 1192,
+ 1193,
+ 1195,
+ 1196,
+ 1197,
+ 1198,
+ 1199,
+ 1200,
+ 1201,
+ 1202,
+ 1203,
+ 1204,
+ 1205,
+ 1206,
+ 1207,
+ 1208,
+ 1209,
+ 1210,
+ 1211,
+ 1212,
+ 1213,
+ 1214,
+ 1215,
+ 1216,
+ 1217,
+ 1218,
+ 1219,
+ 1220,
+ 1221,
+ 1222,
+ 1223,
+ 1224,
+ 1225,
+ 1226,
+ 1227,
+ 1228,
+ 1229,
+ 1230,
+ 1231,
+ 1232,
+ 1233,
+ 1234,
+ 1235,
+ 1236,
+ 1237,
+ 1238,
+ 1239,
+ 1240,
+ 1241,
+ 1242,
+ 1244,
+ 1245,
+ 1246,
+ 1247,
+ 1248,
+ 1249,
+ 1250,
+ 1251,
+ 1252,
+ 1253,
+ 1254,
+ 1255,
+ 1256,
+ 1257,
+ 1258,
+ 1259,
+ 1260,
+ 1261,
+ 1262,
+ 1263,
+ 1264,
+ 1265,
+ 1266,
+ 1267,
+ 1268,
+ 1269,
+ 1270,
+ 1271,
+ 1272,
+ 1273,
+ 1274,
+ 1275,
+ 1276,
+ 1277,
+ 1278,
+ 1279,
+ 1280,
+ 1281,
+ 1282,
+ 1283,
+ 1284,
+ 1285,
+ 1286,
+ 1287,
+ 1288,
+ 1289,
+ 1290,
+ 1291,
+ 1292,
+ 1293,
+ 1294,
+ 1295,
+ 1296,
+ 1297,
+ 1298,
+ 1299,
+ 1300,
+ 1301,
+ 1302,
+ 1303,
+ 1304,
+ 1305,
+ 1306,
+ 1307,
+ 1308,
+ 1309,
+ 1310,
+ 1311,
+ 1312,
+ 1313,
+ 1314,
+ 1315,
+ 1316,
+ 1317,
+ 1318,
+ 1319,
+ 1320,
+ 1321,
+ 1322,
+ 1323,
+ 1324,
+ 1325,
+ 1326,
+ 1327,
+ 1328,
+ 1329,
+ 1330,
+ 1331,
+ 1332,
+ 1333,
+ 1334,
+ 1335,
+ 1336,
+ 1337,
+ 1338,
+ 1339,
+ 1340,
+ 1341,
+ 1342,
+ 1343,
+ 1344,
+ 1345,
+ 1346,
+ 1347,
+ 1348,
+ 1349,
+ 1350,
+ 1351,
+ 1352,
+ 1353,
+ 1354,
+ 1355,
+ 1356,
+ 1357,
+ 1358,
+ 1359,
+ 1360,
+ 1361,
+ 1362,
+ 1363,
+ 1364,
+ 1365,
+ 1366,
+ 1367,
+ 1368,
+ 1369,
+ 1370,
+ 1371,
+ 1372,
+ 1373,
+ 1374,
+ 1375,
+ 1376,
+ 1377,
+ 1378,
+ 1379,
+ 1380,
+ 1381,
+ 1382,
+ 1383,
+ 1384,
+ 1385,
+ 1386,
+ 1387,
+ 1388,
+ 1389,
+ 1390,
+ 1391,
+ 1392,
+ 1393,
+ 1394,
+ 1395,
+ 1396,
+ 1397,
+ 1398,
+ 1399,
+ 1400,
+ 1401,
+ 1402,
+ 1403,
+ 1404,
+ 1405,
+ 1406,
+ 1407,
+ 1408,
+ 1409,
+ 1410,
+ 1411,
+ 1412,
+ 1413,
+ 1414,
+ 1415,
+ 1416,
+ 1417,
+ 1418,
+ 1419,
+ 1420,
+ 1421,
+ 1422,
+ 1423,
+ 1424,
+ 1425,
+ 1426,
+ 1427,
+ 1428,
+ 1429,
+ 1430,
+ 1431,
+ 1432,
+ 1433,
+ 1434,
+ 1435,
+ 1436,
+ 1437,
+ 1438,
+ 1439,
+ 1440,
+ 1441,
+ 1442,
+ 1443,
+ 1444,
+ 1445,
+ 1446,
+ 1447,
+ 1448,
+ 1449,
+ 1450,
+ 1451,
+ 1452,
+ 1453,
+ 1454,
+ 1455,
+ 1456,
+ 1457,
+ 1458,
+ 1459,
+ 1460,
+ 1461,
+ 1462,
+ 1463,
+ 1464,
+ 1465,
+ 1466,
+ 1467,
+ 1468,
+ 1469,
+ 1470,
+ 1471,
+ 1472,
+ 1473,
+ 1474,
+ 1475,
+ 1476,
+ 1477,
+ 1478,
+ 1479,
+ 1480,
+ 1481,
+ 1482,
+ 1483,
+ 1484,
+ 1485,
+ 1486,
+ 1487,
+ 1488,
+ 1489,
+ 1490,
+ 1491,
+ 1492,
+ 1493,
+ 1494,
+ 1495,
+ 1496,
+ 1497,
+ 1498,
+ 1499,
+ 1500,
+ 1501,
+ 1502,
+ 1503,
+ 1504,
+ 1505,
+ 1506,
+ 1507,
+ 1508,
+ 1509,
+ 1510,
+ 1511,
+ 1512,
+ 1513,
+ 1514,
+ 1515,
+ 1516,
+ 1517,
+ 1518,
+ 1519,
+ 1520,
+ 1521,
+ 1522,
+ 1523,
+ 1524,
+ 1525,
+ 1526,
+ 1527,
+ 1528,
+ 1529,
+ 1530,
+ 1531,
+ 1532,
+ 1533,
+ 1534,
+ 1535,
+ 1536,
+ 1537,
+ 1538,
+ 1539,
+ 1540,
+ 1541,
+ 1542,
+ 1543,
+ 1544,
+ 1545,
+ 1546,
+ 1547,
+ 1548,
+ 1549,
+ 1550,
+ 1551,
+ 1552,
+ 1553,
+ 1554,
+ 1555,
+ 1556,
+ 1557,
+ 1558,
+ 1559,
+ 1560,
+ 1561,
+ 1562,
+ 1563,
+ 1564,
+ 1565,
+ 1566,
+ 1567,
+ 1568,
+ 1569,
+ 1570,
+ 1571,
+ 1572,
+ 1573,
+ 1574,
+ 1575,
+ 1576,
+ 1577,
+ 1578,
+ 1579,
+ 1580,
+ 1581,
+ 1582,
+ 1583,
+ 1584,
+ 1585,
+ 1586,
+ 1587,
+ 1588,
+ 1589,
+ 1590,
+ 1591,
+ 1592,
+ 1593,
+ 1594,
+ 1595,
+ 1596,
+ 1597,
+ 1598,
+ 1599,
+ 1600,
+ 1601,
+ 1602,
+ 1603,
+ 1604,
+ 1605,
+ 1606,
+ 1607,
+ 1608,
+ 1609,
+ 1610,
+ 1611,
+ 1612,
+ 1613,
+ 1614,
+ 1615,
+ 1616,
+ 1617,
+ 1618,
+ 1619,
+ 1620,
+ 1621,
+ 1622,
+ 1623,
+ 1624,
+ 1625,
+ 1626,
+ 1627,
+ 1628,
+ 1629,
+ 1630,
+ 1631,
+ 1632,
+ 1633,
+ 1634,
+ 1635,
+ 1636,
+ 1637,
+ 1638,
+ 1639,
+ 1640,
+ 1641,
+ 1642,
+ 1643,
+ 1644,
+ 1645,
+ 1646,
+ 1647,
+ 1648,
+ 1649,
+ 1650,
+ 1651,
+ 1652,
+ 1653,
+ 1654,
+ 1655,
+ 1656,
+ 1657,
+ 1658,
+ 1659,
+ 1660,
+ 1661,
+ 1662,
+ 1663,
+ 1664,
+ 1665,
+ 1666,
+ 1667,
+ 1668,
+ 1669,
+ 1670,
+ 1671,
+ 1672,
+ 1673,
+ 1674,
+ 1675,
+ 1676,
+ 1677,
+ 1678,
+ 1679,
+ 1680,
+ 1681,
+ 1682,
+ 1683,
+ 1684,
+ 1685,
+ 1686,
+ 1687,
+ 1688,
+ 1689,
+ 1690,
+ 1691,
+ 1692,
+ 1693,
+ 1694,
+ 1695,
+ 1696,
+ 1697,
+ 1698,
+ 1699,
+ 1700,
+ 1701,
+ 1702,
+ 1703,
+ 1704,
+ 1705,
+ 1706,
+ 1707,
+ 1708,
+ 1709,
+ 1710,
+ 1711,
+ 1712,
+ 1713,
+ 1714,
+ 1715,
+ 1716,
+ 1717,
+ 1718,
+ 1719,
+ 1720,
+ 1721,
+ 1722,
+ 1723,
+ 1724,
+ 1725,
+ 1726,
+ 1727,
+ 1728,
+ 1729,
+ 1730,
+ 1731,
+ 1732,
+ 1733,
+ 1734,
+ 1735,
+ 1736,
+ 1737,
+ 1738,
+ 1740,
+ 1741,
+ 1742,
+ 1743,
+ 1744,
+ 1745,
+ 1746,
+ 1747,
+ 1748,
+ 1749,
+ 1750,
+ 1751,
+ 1752,
+ 1753,
+ 1754,
+ 1755,
+ 1756,
+ 1757,
+ 1758,
+ 1759,
+ 1760,
+ 1761,
+ 1762,
+ 1763,
+ 1764,
+ 1765,
+ 1766,
+ 1767,
+ 1768,
+ 1769,
+ 1770,
+ 1771,
+ 1772,
+ 1773,
+ 1774,
+ 1775,
+ 1776,
+ 1777,
+ 1778,
+ 1779,
+ 1780,
+ 1781,
+ 1782,
+ 1783,
+ 1784,
+ 1785,
+ 1786,
+ 1787,
+ 1788,
+ 1789,
+ 1790,
+ 1791,
+ 1792,
+ 1793,
+ 1794,
+ 1795,
+ 1841,
+ 1864,
+ 1865,
+ 1866,
+ 1867,
+ 1868,
+ 1869,
+ 1871,
+ 1872,
+ 1873,
+ 1874,
+ 1875,
+ 1876,
+ 1877,
+ 1878,
+ 1879,
+ 1880,
+ 1881,
+ 1882,
+ 1883,
+ 1884,
+ 1885,
+ 1886,
+ 1887,
+ 1888,
+ 1889,
+ 1890,
+ 1891,
+ 1892,
+ 1893,
+ 1894,
+ 1895,
+ 1896,
+ 1897,
+ 1898,
+ 1899,
+ 1900,
+ 1901,
+ 1902,
+ 1903,
+ 1904,
+ 1905,
+ 1906,
+ 1907,
+ 1908,
+ 1909,
+ 1910,
+ 1911,
+ 1912,
+ 1913,
+ 1914,
+ 1915,
+ 1916,
+ 1917,
+ 1918,
+ 1920,
+ 1921,
+ 1922,
+ 1923,
+ 1924,
+ 1925,
+ 1926,
+ 1927,
+ 1928,
+ 1929,
+ 1930,
+ 1931,
+ 1932,
+ 1933,
+ 1934,
+ 1935,
+ 1936,
+ 1937,
+ 1938,
+ 1939,
+ 1940,
+ 1941,
+ 1942,
+ 1943,
+ 1944,
+ 1945,
+ 1946,
+ 1947,
+ 1948,
+ 1949,
+ 1950,
+ 1951,
+ 1952,
+ 1953,
+ 1954,
+ 1955,
+ 1956,
+ 1957,
+ 1958,
+ 1959,
+ 1962,
+ 1963,
+ 1964,
+ 1966,
+ 1968,
+ 1969,
+ 1970,
+ 1971,
+ 1972,
+ 1973,
+ 1974,
+ 1975,
+ 1976,
+ 1977,
+ 1978,
+ 1979,
+ 1980,
+ 1981,
+ 1982,
+ 1983,
+ 1984,
+ 1985,
+ 1986,
+ 1987,
+ 1988,
+ 1989,
+ 1990,
+ 1991,
+ 1992,
+ 1993,
+ 1994,
+ 1995,
+ 1996,
+ 1997,
+ 1998,
+ 1999,
+ 2000,
+ 2001,
+ 2002,
+ 2004,
+ 2005,
+ 2006,
+ 2007,
+ 2008,
+ 2009,
+ 2010,
+ 2011
+ ]
+}
\ No newline at end of file
diff --git a/data/statistics/statistics_modified_v1.json b/data/statistics/statistics_modified_v1.json
new file mode 100644
index 0000000000000000000000000000000000000000..2eb46afa77a3dd46b8c8c41eb97939b606001550
--- /dev/null
+++ b/data/statistics/statistics_modified_v1.json
@@ -0,0 +1,615 @@
+{
+ "trans_mean": [
+ 0.02,
+ 0.0,
+ 14.79
+ ],
+ "trans_std": [
+ 0.10,
+ 0.10,
+ 2.65
+ ],
+ "flength_mean": [
+ 2169.0
+ ],
+ "flength_std": [
+ 448.0
+ ],
+ "pose_mean": [
+ [
+ [
+ 0.44,
+ 0.0,
+ -0.0
+ ],
+ [
+ 0.0,
+ 0.0,
+ -1.0
+ ],
+ [
+ -0.0,
+ 0.44,
+ 0.0
+ ]
+ ],
+ [
+ [
+ 0.97,
+ -0.0,
+ -0.08
+ ],
+ [
+ 0.0,
+ 0.98,
+ 0.0
+ ],
+ [
+ 0.08,
+ -0.0,
+ 0.98
+ ]
+ ],
+ [
+ [
+ 0.98,
+ 0.0,
+ 0.01
+ ],
+ [
+ -0.0,
+ 0.99,
+ 0.0
+ ],
+ [
+ -0.01,
+ 0.0,
+ 0.98
+ ]
+ ],
+ [
+ [
+ 0.98,
+ -0.0,
+ -0.03
+ ],
+ [
+ 0.0,
+ 0.99,
+ 0.0
+ ],
+ [
+ 0.04,
+ -0.0,
+ 0.98
+ ]
+ ],
+ [
+ [
+ 0.98,
+ 0.0,
+ 0.02
+ ],
+ [
+ -0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ -0.02,
+ -0.0,
+ 0.98
+ ]
+ ],
+ [
+ [
+ 0.99,
+ 0.0,
+ -0.0
+ ],
+ [
+ -0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ 0.0,
+ 0.0,
+ 0.99
+ ]
+ ],
+ [
+ [
+ 0.99,
+ 0.0,
+ 0.03
+ ],
+ [
+ 0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ -0.03,
+ 0.0,
+ 0.99
+ ]
+ ],
+ [
+ [
+ 0.95,
+ -0.05,
+ 0.04
+ ],
+ [
+ 0.05,
+ 0.98,
+ -0.01
+ ],
+ [
+ -0.03,
+ 0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.91,
+ -0.01,
+ -0.19
+ ],
+ [
+ -0.01,
+ 0.98,
+ -0.05
+ ],
+ [
+ 0.19,
+ 0.03,
+ 0.91
+ ]
+ ],
+ [
+ [
+ 0.85,
+ -0.04,
+ 0.23
+ ],
+ [
+ -0.0,
+ 0.99,
+ 0.07
+ ],
+ [
+ -0.23,
+ -0.06,
+ 0.85
+ ]
+ ],
+ [
+ [
+ 0.93,
+ 0.0,
+ 0.16
+ ],
+ [
+ -0.01,
+ 0.99,
+ 0.01
+ ],
+ [
+ -0.16,
+ -0.02,
+ 0.93
+ ]
+ ],
+ [
+ [
+ 0.95,
+ 0.05,
+ 0.03
+ ],
+ [
+ -0.05,
+ 0.98,
+ 0.02
+ ],
+ [
+ -0.03,
+ -0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.91,
+ 0.01,
+ -0.19
+ ],
+ [
+ 0.02,
+ 0.98,
+ 0.05
+ ],
+ [
+ 0.2,
+ -0.03,
+ 0.91
+ ]
+ ],
+ [
+ [
+ 0.84,
+ 0.03,
+ 0.24
+ ],
+ [
+ 0.01,
+ 0.99,
+ -0.06
+ ],
+ [
+ -0.24,
+ 0.07,
+ 0.84
+ ]
+ ],
+ [
+ [
+ 0.93,
+ -0.0,
+ 0.18
+ ],
+ [
+ 0.01,
+ 0.99,
+ -0.01
+ ],
+ [
+ -0.18,
+ 0.02,
+ 0.93
+ ]
+ ],
+ [
+ [
+ 0.95,
+ -0.0,
+ 0.01
+ ],
+ [
+ -0.0,
+ 0.96,
+ 0.0
+ ],
+ [
+ -0.0,
+ -0.0,
+ 0.99
+ ]
+ ],
+ [
+ [
+ 0.93,
+ 0.0,
+ -0.11
+ ],
+ [
+ -0.0,
+ 0.97,
+ -0.0
+ ],
+ [
+ 0.12,
+ 0.0,
+ 0.95
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.04,
+ -0.06
+ ],
+ [
+ 0.03,
+ 0.98,
+ -0.02
+ ],
+ [
+ 0.06,
+ 0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ 0.05,
+ 0.04
+ ],
+ [
+ -0.05,
+ 0.98,
+ -0.05
+ ],
+ [
+ -0.05,
+ 0.05,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.0,
+ -0.09
+ ],
+ [
+ -0.0,
+ 0.99,
+ 0.01
+ ],
+ [
+ 0.09,
+ -0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ 0.0,
+ 0.06
+ ],
+ [
+ -0.02,
+ 0.98,
+ 0.05
+ ],
+ [
+ -0.05,
+ -0.06,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ 0.04,
+ -0.07
+ ],
+ [
+ -0.03,
+ 0.98,
+ 0.02
+ ],
+ [
+ 0.07,
+ -0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.05,
+ 0.04
+ ],
+ [
+ 0.04,
+ 0.98,
+ 0.05
+ ],
+ [
+ -0.05,
+ -0.04,
+ 0.97
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.0,
+ -0.09
+ ],
+ [
+ 0.0,
+ 0.99,
+ -0.01
+ ],
+ [
+ 0.09,
+ 0.01,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.0,
+ 0.06
+ ],
+ [
+ 0.02,
+ 0.98,
+ -0.05
+ ],
+ [
+ -0.05,
+ 0.06,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.73,
+ 0.0,
+ -0.4
+ ],
+ [
+ -0.0,
+ 0.98,
+ 0.0
+ ],
+ [
+ 0.39,
+ 0.0,
+ 0.73
+ ]
+ ],
+ [
+ [
+ 0.95,
+ -0.0,
+ -0.07
+ ],
+ [
+ 0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ 0.07,
+ 0.0,
+ 0.95
+ ]
+ ],
+ [
+ [
+ 0.98,
+ 0.0,
+ -0.09
+ ],
+ [
+ -0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ 0.09,
+ 0.0,
+ 0.98
+ ]
+ ],
+ [
+ [
+ 0.99,
+ -0.0,
+ 0.03
+ ],
+ [
+ 0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ -0.03,
+ 0.0,
+ 0.99
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.0,
+ 0.1
+ ],
+ [
+ 0.0,
+ 0.98,
+ -0.0
+ ],
+ [
+ -0.09,
+ 0.0,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.79,
+ -0.01,
+ 0.21
+ ],
+ [
+ 0.01,
+ 0.96,
+ 0.0
+ ],
+ [
+ -0.2,
+ 0.0,
+ 0.82
+ ]
+ ],
+ [
+ [
+ 0.89,
+ -0.0,
+ 0.07
+ ],
+ [
+ 0.0,
+ 0.98,
+ 0.0
+ ],
+ [
+ -0.07,
+ 0.0,
+ 0.9
+ ]
+ ],
+ [
+ [
+ 0.96,
+ -0.0,
+ 0.09
+ ],
+ [
+ 0.0,
+ 0.99,
+ -0.0
+ ],
+ [
+ -0.1,
+ 0.0,
+ 0.96
+ ]
+ ],
+ [
+ [
+ 0.93,
+ -0.09,
+ -0.07
+ ],
+ [
+ 0.1,
+ 0.93,
+ 0.03
+ ],
+ [
+ 0.03,
+ -0.06,
+ 0.95
+ ]
+ ],
+ [
+ [
+ 0.86,
+ 0.1,
+ -0.37
+ ],
+ [
+ -0.12,
+ 0.94,
+ 0.01
+ ],
+ [
+ 0.35,
+ 0.05,
+ 0.88
+ ]
+ ]
+ ]
+}
diff --git a/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg b/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c3120182f24a6246e2fca3fdccbb905a30b70cc8
Binary files /dev/null and b/datasets/test_image_crops/201030094143-stock-rhodesian-ridgeback-super-tease.jpg differ
diff --git a/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg b/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..73092b851ac8b555406365a34ee91a8d8069b9bc
Binary files /dev/null and b/datasets/test_image_crops/Akita-standing-outdoors-in-the-summer-400x267.jpg differ
diff --git a/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png b/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png
new file mode 100644
index 0000000000000000000000000000000000000000..16b538f73b30ba8bd00816a9804803c62b12ad6a
Binary files /dev/null and b/datasets/test_image_crops/image_n02089078-black-and-tan_coonhound_n02089078_3810.png differ
diff --git a/gradio_demo/barc_demo_v3.py b/gradio_demo/barc_demo_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..84c8b11541419092853bdbaf9b96b2b406e1dc8a
--- /dev/null
+++ b/gradio_demo/barc_demo_v3.py
@@ -0,0 +1,268 @@
+# python gradio_demo/barc_demo_v3.py
+
+import numpy as np
+import os
+import glob
+import torch
+from torch.utils.data import DataLoader
+import torchvision
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+import torchvision.transforms as T
+import cv2
+from matplotlib import pyplot as plt
+from PIL import Image
+
+import gradio as gr
+
+
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../', 'src'))
+from stacked_hourglass.datasets.imgcropslist import ImgCrops
+from combined_model.train_main_image_to_3d_withbreedrel import do_visual_epoch
+from combined_model.model_shape_v7 import ModelImageTo3d_withshape_withproj
+
+from configs.barc_cfg_defaults import get_cfg_global_updated
+
+
+
+def get_prediction(model, img_path_or_img, confidence=0.5):
+ """
+ see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
+ get_prediction
+ parameters:
+ - img_path_or_img - path of the input image or the image itself
+ - confidence - threshold value for prediction score
+ method:
+ - Image is obtained from the image path
+ - the image is converted to image tensor using PyTorch's Transforms
+ - image is passed through the model to get the predictions
+ - classes and box coordinates are obtained, but only predictions with a
+ score > threshold are kept
+
+ """
+ if isinstance(img_path_or_img, str):
+ img = Image.open(img_path_or_img).convert('RGB')
+ else:
+ img = img_path_or_img
+ transform = T.Compose([T.ToTensor()])
+ img = transform(img)
+ pred = model([img])
+ # pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
+ pred_class = list(pred[0]['labels'].numpy())
+ pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
+ pred_score = list(pred[0]['scores'].detach().numpy())
+ try:
+ pred_t = [pred_score.index(x) for x in pred_score if x>confidence][-1]
+ pred_boxes = pred_boxes[:pred_t+1]
+ pred_class = pred_class[:pred_t+1]
+ return pred_boxes, pred_class, pred_score
+ except IndexError:
+ print('no bounding box with a score that is high enough found! -> work on full image')
+ return None, None, None
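+# minimal usage sketch (the image path below is only an example):
+#   detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).eval()
+#   boxes, classes, scores = get_prediction(detector, 'datasets/test_image_crops/some_dog.jpg', confidence=0.5)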
+
+def detect_object(model, img_path_or_img, confidence=0.5, rect_th=2, text_size=0.5, text_th=1):
+ """
+ see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
+ object_detection_api
+ parameters:
+ - img_path_or_img - path of the input image or the image itself
+ - confidence - threshold value for prediction score
+ - rect_th - thickness of bounding box
+ - text_size - size of the class label text
+ - text_th - thickness of the text
+ method:
+ - prediction is obtained from get_prediction method
+ - for each prediction, bounding box is drawn and text is written
+ with opencv
+ - the annotated image and the bounding box of the first detected dog are returned
+ """
+ boxes, pred_cls, pred_scores = get_prediction(model, img_path_or_img, confidence)
+ if isinstance(img_path_or_img, str):
+ img = cv2.imread(img_path_or_img)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ else:
+ img = img_path_or_img
+ is_first = True
+ bbox = None
+ if boxes is not None:
+ for i in range(len(boxes)):
+ cls = pred_cls[i]
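+ # COCO instance label 18 is the 'dog' class for the torchvision Faster R-CNN model;
+ # detections are returned sorted by score, so we keep the first dog box only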
+ if cls == 18 and bbox is None:
+ cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
+ # cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
+ cv2.putText(img, str(pred_scores[i]), boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
+ bbox = boxes[i]
+ return img, bbox
+
+
+
+def run_bbox_inference(input_image):
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+ model.eval()
+ cfg = get_cfg_global_updated()
+ out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png')
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ img, bbox = detect_object(model=model, img_path_or_img=input_image, confidence=0.5)
+ fig = plt.figure() # plt.figure(figsize=(20,30))
+ plt.imsave(out_path, img)
+ return img, bbox
+
+
+
+
+
+def run_barc_inference(input_image, bbox=None):
+
+ # load configs
+ cfg = get_cfg_global_updated()
+
+ model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, 'barc_complete', 'model_best.pth.tar')
+
+
+
+ # Select the hardware device to use for inference.
+ if torch.cuda.is_available() and cfg.device=='cuda':
+ device = torch.device('cuda', torch.cuda.current_device())
+ # torch.backends.cudnn.benchmark = True
+ else:
+ device = torch.device('cpu')
+
+ path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, model_file_complete)
+
+ # Disable gradient calculations.
+ torch.set_grad_enabled(False)
+
+ # prepare complete model
+ complete_model = ModelImageTo3d_withshape_withproj(
+ num_stage_comb=cfg.params.NUM_STAGE_COMB, num_stage_heads=cfg.params.NUM_STAGE_HEADS, \
+ num_stage_heads_pose=cfg.params.NUM_STAGE_HEADS_POSE, trans_sep=cfg.params.TRANS_SEP, \
+ arch=cfg.params.ARCH, n_joints=cfg.params.N_JOINTS, n_classes=cfg.params.N_CLASSES, \
+ n_keyp=cfg.params.N_KEYP, n_bones=cfg.params.N_BONES, n_betas=cfg.params.N_BETAS, n_betas_limbs=cfg.params.N_BETAS_LIMBS, \
+ n_breeds=cfg.params.N_BREEDS, n_z=cfg.params.N_Z, image_size=cfg.params.IMG_SIZE, \
+ silh_no_tail=cfg.params.SILH_NO_TAIL, thr_keyp_sc=cfg.params.KP_THRESHOLD, add_z_to_3d_input=cfg.params.ADD_Z_TO_3D_INPUT,
+ n_segbps=cfg.params.N_SEGBPS, add_segbps_to_3d_input=cfg.params.ADD_SEGBPS_TO_3D_INPUT, add_partseg=cfg.params.ADD_PARTSEG, n_partseg=cfg.params.N_PARTSEG, \
+ fix_flength=cfg.params.FIX_FLENGTH, structure_z_to_betas=cfg.params.STRUCTURE_Z_TO_B, structure_pose_net=cfg.params.STRUCTURE_POSE_NET,
+ nf_version=cfg.params.NF_VERSION)
+
+ # load trained model
+ print(path_model_file_complete)
+ assert os.path.isfile(path_model_file_complete)
+ print('Loading model weights from file: {}'.format(path_model_file_complete))
+ checkpoint_complete = torch.load(path_model_file_complete, map_location=device)
+ state_dict_complete = checkpoint_complete['state_dict']
+ complete_model.load_state_dict(state_dict_complete, strict=False)
+ complete_model = complete_model.to(device)
+
+ save_imgs_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples')
+ if not os.path.exists(save_imgs_path):
+ os.makedirs(save_imgs_path)
+
+ input_image_list = [input_image]
+ if bbox is not None:
+ input_bbox_list = [bbox]
+ else:
+ input_bbox_list = None
+ val_dataset = ImgCrops(image_list=input_image_list, bbox_list=input_bbox_list, dataset_mode='complete')
+ test_name_list = val_dataset.test_name_list
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
+ num_workers=0, pin_memory=True, drop_last=False)
+
+ # run visual evaluation
+ # remark: take ACC_Joints and DATA_INFO from StanExt as this is the training dataset
+ all_results = do_visual_epoch(val_loader, complete_model, device,
+ ImgCrops.DATA_INFO,
+ weight_dict=None,
+ acc_joints=ImgCrops.ACC_JOINTS,
+ save_imgs_path=None, # save_imgs_path,
+ metrics='all',
+ test_name_list=test_name_list,
+ render_all=cfg.params.RENDER_ALL,
+ pck_thresh=cfg.params.PCK_THRESH,
+ return_results=True)
+
+ mesh = all_results[0]['mesh_posed']
+ result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z')
+
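+ # the 4x4 homogeneous transform below rotates the mesh by 180 degrees around the z-axis
+ # (x and y are negated) and shifts it by one unit along z, presumably to orient and frame
+ # the dog correctly in the gradio Model3D viewer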
+ mesh.apply_transform([[-1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, 1, 1],
+ [0, 0, 0, 1]])
+ mesh.export(file_obj=result_path + '.glb')
+ result_gltf = result_path + '.glb'
+ return [result_gltf, result_gltf]
+
+
+
+
+
+
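+# full demo pipeline: first detect the dog with an off-the-shelf Faster R-CNN to obtain a
+# bounding box, then run BARC on the image (cropped to that box) to regress the 3D dog mesh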
+def run_complete_inference(input_image):
+
+ output_interm_image, output_interm_bbox = run_bbox_inference(input_image.copy())
+
+ print(output_interm_bbox)
+
+ # output_image = run_barc_inference(input_image)
+ output_image = run_barc_inference(input_image, output_interm_bbox)
+
+ return output_image
+
+
+
+
+# demo = gr.Interface(run_barc_inference, gr.Image(), "image")
+# demo = gr.Interface(run_complete_inference, gr.Image(), "image")
+
+
+
+# see: https://huggingface.co./spaces/radames/PIFu-Clothed-Human-Digitization/blob/main/PIFu/spaces.py
+
+description = '''
+# BARC
+
+#### Project Page
+* https://barc.is.tue.mpg.de/
+
+#### Description
+This is a demo for BARC. While BARC is trained on image crops, this demo uses a pretrained Faster-RCNN in order to get bounding boxes for the dogs.
+To see your result you may have to wait a minute or two, so please be patient.
+
+
+
+
+#### Citation
+
+```
+@inproceedings{BARC:2022,
+ title = {{BARC}: Learning to Regress {3D} Dog Shape from Images by Exploiting Breed Information},
+ author = {Rueegg, Nadine and Zuffi, Silvia and Schindler, Konrad and Black, Michael J.},
+ booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2022}
+}
+```
+
+
+'''
+
+examples = sorted(glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.jpg')) + glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.png')))
+
+
+demo = gr.Interface(
+ fn=run_complete_inference,
+ description=description,
+ # inputs=gr.Image(type="filepath", label="Input Image"),
+ inputs=gr.Image(label="Input Image"),
+ outputs=[
+ gr.Model3D(
+ clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
+ gr.File(label="Download 3D Model")
+ ],
+ examples=examples,
+ thumbnail="barc_thumbnail.png",
+ allow_flagging="never",
+ cache_examples=True
+)
+
+
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/src/bps_2d/bps_for_segmentation.py b/src/bps_2d/bps_for_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7382c5e875f878b296321fed6e0c46b037781e
--- /dev/null
+++ b/src/bps_2d/bps_for_segmentation.py
@@ -0,0 +1,114 @@
+
+# code idea from https://github.com/sergeyprokudin/bps
+
+import os
+import numpy as np
+from PIL import Image
+import time
+import scipy
+import scipy.spatial
+import pymp
+
+
+#####################
+QUERY_POINTS = np.asarray([30, 34, 31, 55, 29, 84, 35, 108, 34, 145, 29, 171, 27,
+ 196, 29, 228, 58, 35, 61, 55, 57, 83, 56, 109, 63, 148, 58, 164, 57, 197, 60,
+ 227, 81, 26, 87, 58, 85, 87, 89, 117, 86, 142, 89, 172, 84, 197, 88, 227, 113,
+ 32, 116, 58, 112, 88, 118, 113, 109, 147, 114, 173, 119, 201, 113, 229, 139,
+ 29, 141, 59, 142, 93, 139, 117, 146, 147, 141, 173, 142, 201, 143, 227, 170,
+ 26, 173, 59, 166, 90, 174, 117, 176, 141, 169, 175, 167, 198, 172, 227, 198,
+ 30, 195, 59, 204, 85, 198, 116, 195, 140, 198, 175, 194, 193, 199, 227, 221,
+ 26, 223, 57, 227, 83, 227, 113, 227, 140, 226, 173, 230, 196, 228, 229]).reshape((64, 2))
+#####################
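+# QUERY_POINTS contains 64 fixed (row, col) locations, approximately a jittered 8x8 grid over
+# the 256x256 segmentation image; they are the basis points of the 2D BPS encoding: for every
+# query point, the (normalized) location of the nearest pixel of the opposite class
+# (foreground vs. background) is stored.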
+
+class SegBPS():
+
+ def __init__(self, query_points=QUERY_POINTS, size=256):
+ self.size = size
+ self.query_points = query_points
+ row, col = np.indices((self.size, self.size))
+ self.indices_rc = np.stack((row, col), axis=2) # (256, 256, 2)
+ self.pts_aranged = np.arange(64)
+ return
+
+ def _do_kdtree(self, combined_x_y_arrays, points):
+ # see https://stackoverflow.com/questions/10818546/finding-index-of-nearest-
+ # point-in-numpy-arrays-of-x-and-y-coordinates
+ mytree = scipy.spatial.cKDTree(combined_x_y_arrays)
+ dist, indexes = mytree.query(points)
+ return indexes
+
+ def calculate_bps_points(self, seg, thr=0.5, vis=False, out_path=None):
+ # seg: input segmentation image of shape (256, 256) with values between 0 and 1
+ query_val = seg[self.query_points[:, 0], self.query_points[:, 1]]
+ pts_fg = self.pts_aranged[query_val>=thr]
+ pts_bg = self.pts_aranged[query_val<thr]
+ candidate_inds_fg = self.indices_rc[seg>=thr]
+ candidate_inds_bg = self.indices_rc[seg<thr]
+ if candidate_inds_bg.shape[0] == 0:
+ candidate_inds_bg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
+ if candidate_inds_fg.shape[0] == 0:
+ candidate_inds_fg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
+ # calculate nearest points
+ all_nearest_points = np.zeros((64, 2))
+ all_nearest_points[pts_fg, :] = candidate_inds_bg[self._do_kdtree(candidate_inds_bg, self.query_points[pts_fg, :]), :]
+ all_nearest_points[pts_bg, :] = candidate_inds_fg[self._do_kdtree(candidate_inds_fg, self.query_points[pts_bg, :]), :]
+ all_nearest_points_01 = all_nearest_points / 255.
+ if vis:
+ self.visualize_result(seg, all_nearest_points, out_path=out_path)
+ return all_nearest_points_01
+
+ def calculate_bps_points_batch(self, seg_batch, thr=0.5, vis=False, out_path=None):
+ # seg_batch: input segmentation image of shape (bs, 256, 256) with values between 0 and 1
+ bs = seg_batch.shape[0]
+ all_nearest_points_01_batch = np.zeros((bs, self.query_points.shape[0], 2))
+ for ind in range(0, bs): # 0.25
+ seg = seg_batch[ind, :, :]
+ all_nearest_points_01 = self.calculate_bps_points(seg, thr=thr, vis=vis, out_path=out_path)
+ all_nearest_points_01_batch[ind, :, :] = all_nearest_points_01
+ return all_nearest_points_01_batch
+
+ def visualize_result(self, seg, all_nearest_points, out_path=None):
+ import matplotlib as mpl
+ mpl.use('Agg')
+ import matplotlib.pyplot as plt
+ # img: (256, 256, 3)
+ img = (np.stack((seg, seg, seg), axis=2) * 155).astype(np.uint8)
+ if out_path is None:
+ ind_img = 0
+ out_path = '../test_img' + str(ind_img) + '.png'
+ fig, ax = plt.subplots()
+ plt.imshow(img)
+ plt.gca().set_axis_off()
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
+ plt.margins(0,0)
+ ratio_in_out = 1 # 255
+ for idx, (y, x) in enumerate(self.query_points):
+ x = int(x*ratio_in_out)
+ y = int(y*ratio_in_out)
+ plt.scatter([x], [y], marker="x", s=50)
+ x2 = int(all_nearest_points[idx, 1])
+ y2 = int(all_nearest_points[idx, 0])
+ plt.scatter([x2], [y2], marker="o", s=50)
+ plt.plot([x, x2], [y, y2])
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
+ plt.close()
+ return
+
+
+
+
+
+if __name__ == "__main__":
+ ind_img = 2 # 4
+ path_seg_top = '...../pytorch-stacked-hourglass/results/dogs_hg8_ks_24_v1/test/'
+ path_seg = os.path.join(path_seg_top, 'seg_big_' + str(ind_img) + '.png')
+ img = np.asarray(Image.open(path_seg))
+ # min is 0.004, max is 0.9
+ # low values are background, high values are foreground
+ seg = img[:, :, 1] / 255.
+ # calculate points
+ bps = SegBPS()
+ bps.calculate_bps_points(seg, thr=0.5, vis=False, out_path=None)
+
+
diff --git a/src/combined_model/loss_image_to_3d_withbreedrel.py b/src/combined_model/loss_image_to_3d_withbreedrel.py
new file mode 100644
index 0000000000000000000000000000000000000000..5414f8443d9df4aac7ceb409380474d9d69ef27b
--- /dev/null
+++ b/src/combined_model/loss_image_to_3d_withbreedrel.py
@@ -0,0 +1,277 @@
+
+
+import torch
+import numpy as np
+import pickle as pkl
+
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
+# from priors.pose_prior_35 import Prior
+# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
+from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
+from priors.shape_prior import ShapePrior
+from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa
+from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS
+
+class Loss(torch.nn.Module):
+ def __init__(self, data_info, nf_version=None):
+ super(Loss, self).__init__()
+ self.criterion_regr = torch.nn.MSELoss() # takes the mean
+ self.criterion_class = torch.nn.CrossEntropyLoss()
+ self.data_info = data_info
+ self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
+ self.l_anchor = None
+ self.l_pos = None
+ self.l_neg = None
+
+ if nf_version is not None:
+ self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
+ self.shape_prior = ShapePrior(UNITY_SMAL_SHAPE_PRIOR_DOGS)
+ self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1)
+
+ # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
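+ # -> the dictionary below maps breed indices (as used in target_dict['breed_index']) to the
+ # SMAL shape coefficients (betas) of the corresponding unity dog models; forward() uses
+ # these betas as weak 3d shape supervision for those breeds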
+ with open(UNITY_SMAL_SHAPE_PRIOR_DOGS, 'rb') as f:
+ data = pkl.load(f)
+ dog_betas_unity = data['dogs_betas']
+ self.dog_betas_unity = {29: torch.tensor(dog_betas_unity[0, :]).float(),
+ 91: torch.tensor(dog_betas_unity[1, :]).float(),
+ 84: torch.tensor(0.5*dog_betas_unity[3, :] + 0.5*dog_betas_unity[14, :]).float(),
+ 85: torch.tensor(dog_betas_unity[5, :]).float(),
+ 28: torch.tensor(dog_betas_unity[6, :]).float(),
+ 94: torch.tensor(dog_betas_unity[7, :]).float(),
+ 92: torch.tensor(dog_betas_unity[8, :]).float(),
+ 95: torch.tensor(dog_betas_unity[10, :]).float(),
+ 20: torch.tensor(dog_betas_unity[11, :]).float(),
+ 83: torch.tensor(dog_betas_unity[12, :]).float(),
+ 99: torch.tensor(dog_betas_unity[16, :]).float()}
+
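+ # prepare_anchor_pos_neg assumes that each batch consists of consecutive pairs of images
+ # showing the same breed (indices 2k and 2k+1): every image acts as anchor with its partner
+ # as positive and each image from a different pair as negative for the breed triplet loss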
+ def prepare_anchor_pos_neg(self, batch_size, device):
+ l0 = np.arange(0, batch_size, 2)
+ l_anchor = []
+ l_pos = []
+ l_neg = []
+ for ind in l0:
+ xx = set(np.arange(0, batch_size))
+ xx.discard(ind)
+ xx.discard(ind+1)
+ for ind2 in xx:
+ if ind2 % 2 == 0:
+ l_anchor.append(ind)
+ l_pos.append(ind + 1)
+ else:
+ l_anchor.append(ind + 1)
+ l_pos.append(ind)
+ l_neg.append(ind2)
+ self.l_anchor = torch.Tensor(l_anchor).to(torch.int64).to(device)
+ self.l_pos = torch.Tensor(l_pos).to(torch.int64).to(device)
+ self.l_neg = torch.Tensor(l_neg).to(torch.int64).to(device)
+ return
+
+
+ def forward(self, output_reproj, target_dict, weight_dict=None):
+ # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
+ # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
+ batch_size = output_reproj['keyp_2d'].shape[0]
+
+ # loss on reprojected keypoints
+ output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2))
+ target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
+ weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
+ keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
+ loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
+ max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
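+ # (in words: a weighted mean of the per-keypoint euclidean pixel distances, where only
+ # keypoints with a positive visibility weight contribute and self.keypoint_weights rescales
+ # the contribution of individual joints)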
+
+ # loss on reprojected silhouette
+ assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
+ silh_loss_type = 'default'
+ if silh_loss_type == 'default':
+ with torch.no_grad():
+ thr_silh = 20
+ diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
+ diff_x = diff.reshape((batch_size, -1))
+ weights_resh_x = weights_resh.reshape((batch_size, -1))
+ unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
+ loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2]*output_reproj['silh'].shape[3])
+ loss_silh = loss_silh_bs[unweighted_kp_mean_dist < thr_silh].sum() / batch_size
+ else:
+ raise NotImplementedError
+
+ # shape regularization loss on the pca shape coefficients (betas)
+ # -> keeps the predicted shape close to the dog shape distribution (see ShapePrior in __init__)
+ loss_shape = self.shape_prior(output_reproj['betas'])
+ loss_shape_weighted = loss_shape * weight_dict['shape']
+
+ # 3d model loss for the unity dogs (the optional shape prior for 11 breeds, see __init__)
+ loss_models3d = torch.zeros((1)).mean().to(output_reproj['betas'].device)
+ if 'models3d' in weight_dict.keys():
+ if weight_dict['models3d'] > 0:
+ for ind_dog in range(target_dict['breed_index'].shape[0]):
+ breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy())
+ if breed_index in self.dog_betas_unity.keys():
+ betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device)
+ betas_output = output_reproj['betas'][ind_dog, :]
+ betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :]
+ loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1])
+ else:
+ weight_dict['models3d'] = 0
+
+ # shape regularization loss on shapedirs
+ # -> in the current version shapedirs are kept fixed, so we don't need those losses
+ if weight_dict['shapedirs'] > 0:
+ raise NotImplementedError
+ else:
+ loss_shapedirs = torch.zeros((1)).mean().to(output_reproj['betas'].device)
+
+ # prior on back joints (not used in cvpr 2022 paper)
+ # -> elementwise MSE loss on all 6 coefficients of 6d rotation representation
+ if 'pose_0' in weight_dict.keys():
+ if weight_dict['pose_0'] > 0:
+ pred_pose_rot6d = output_reproj['pose_rot6d']
+ w_rj_np = np.zeros((pred_pose_rot6d.shape[1]))
+ w_rj_np[[2, 3, 4, 5]] = 1.0 # back
+ w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device)
+ zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1))
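+ # zero_rot is the 6d representation of the identity rotation (assuming the usual convention
+ # of rot6d_to_rotmat, where the 6 values are reshaped to (3, 2) and the two columns span the
+ # first two axes of the rotation matrix), so this term pulls the back joints towards no rotation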
+ loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None])
+ else:
+ loss_pose = torch.zeros((1)).mean()
+
+ # pose prior
+ # -> we did experiment with different pose priors, for example:
+ # * similar to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py,
+ # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py)
+ # * vae
+ # * normalizing flow pose prior
+ # -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below
+ if 'poseprior' in weight_dict.keys():
+ if weight_dict['poseprior'] > 0:
+ pred_pose_rot6d = output_reproj['pose_rot6d']
+ pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
+ if 'normalizing_flow_tiger' in weight_dict['poseprior_options']:
+ if output_reproj['normflow_z'] is not None:
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square')
+ else:
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square')
+ elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']:
+ if output_reproj['normflow_z'] is not None:
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob')
+ else:
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob')
+ else:
+ raise NotImplementedError
+ else:
+ loss_poseprior = torch.zeros((1)).mean()
+ else:
+ weight_dict['poseprior'] = 0
+ loss_poseprior = torch.zeros((1)).mean()
+
+ # add a prior which penalizes side-movement angles for legs
+ if 'poselegssidemovement' in weight_dict.keys():
+ use_pose_legs_side_loss = True
+ else:
+ use_pose_legs_side_loss = False
+ if use_pose_legs_side_loss:
+ leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
+ leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
+ vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype)
+ vec[2] = -1
+ x0_rotmat = pred_pose
+ x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
+ x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
+ x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
+ x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
+ eps=0 # 1e-7
+ # use the component of the vector which points to the side
+ loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
+ else:
+ loss_poselegssidemovement = torch.zeros((1)).mean()
+ weight_dict['poselegssidemovement'] = 0
+
+ # dog breed classification loss
+ dog_breed_gt = target_dict['breed_index']
+ dog_breed_pred = output_reproj['dog_breed']
+ loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt)
+
+ # dog breed relationship loss
+ # -> we did experiment with many other options, but none was significantly better
+ if '4' in weight_dict['breed_options']: # we have pairs of dogs of the same breed
+ assert weight_dict['breed'] > 0
+ z = output_reproj['z']
+ # go through all same-breed pairs and compare each of them to every sample from the other pairs
+ if self.l_anchor is None:
+ self.prepare_anchor_pos_neg(batch_size, z.device)
+ anchor = torch.index_select(z, 0, self.l_anchor)
+ positive = torch.index_select(z, 0, self.l_pos)
+ negative = torch.index_select(z, 0, self.l_neg)
+ loss_breed = self.criterion_triplet(anchor, positive, negative)
+ else:
+ loss_breed = torch.zeros((1)).mean()
+
+ # regularization for focal length
+ loss_flength_near_mean = torch.mean(output_reproj['flength']**2)
+ loss_flength = loss_flength_near_mean
+
+ # bodypart segmentation loss
+ if 'partseg' in weight_dict.keys():
+ if weight_dict['partseg'] > 0:
+ raise NotImplementedError
+ else:
+ loss_partseg = torch.zeros((1)).mean()
+ else:
+ weight_dict['partseg'] = 0
+ loss_partseg = torch.zeros((1)).mean()
+
+ # weight and combine losses
+ loss_keyp_weighted = loss_keyp * weight_dict['keyp']
+ loss_silh_weighted = loss_silh * weight_dict['silh']
+ loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs']
+ loss_pose_weighted = loss_pose * weight_dict['pose_0']
+ loss_class_weighted = loss_class * weight_dict['class']
+ loss_breed_weighted = loss_breed * weight_dict['breed']
+ loss_flength_weighted = loss_flength * weight_dict['flength']
+ loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior']
+ loss_partseg_weighted = loss_partseg * weight_dict['partseg']
+ loss_models3d_weighted = loss_models3d * weight_dict['models3d']
+ loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement']
+
+ ####################################################################################################
+ loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \
+ loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \
+ loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted
+ ####################################################################################################
+
+ loss_dict = {'loss': loss.item(),
+ 'loss_keyp_weighted': loss_keyp_weighted.item(), \
+ 'loss_silh_weighted': loss_silh_weighted.item(), \
+ 'loss_shape_weighted': loss_shape_weighted.item(), \
+ 'loss_shapedirs_weighted': loss_shapedirs_weighted.item(), \
+ 'loss_pose0_weighted': loss_pose_weighted.item(), \
+ 'loss_class_weighted': loss_class_weighted.item(), \
+ 'loss_breed_weighted': loss_breed_weighted.item(), \
+ 'loss_flength_weighted': loss_flength_weighted.item(), \
+ 'loss_poseprior_weighted': loss_poseprior_weighted.item(), \
+ 'loss_partseg_weighted': loss_partseg_weighted.item(), \
+ 'loss_models3d_weighted': loss_models3d_weighted.item(), \
+ 'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item()}
+
+ return loss, loss_dict
+
+
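+ # Minimal usage sketch (hypothetical values, shown only to illustrate the weight_dict keys
+ # that forward() expects; the real values come from the training configuration):
+ #   loss_module = Loss(data_info, nf_version=3)
+ #   weight_dict = {'keyp': 0.2, 'silh': 50.0, 'shape': 0.1, 'shapedirs': 0, 'pose_0': 0.0,
+ #                  'class': 0.1, 'breed': 5.0, 'breed_options': ['4'], 'flength': 0.0,
+ #                  'poseprior': 0.1, 'poseprior_options': ['normalizing_flow_tiger_logprob'],
+ #                  'partseg': 0, 'models3d': 0.1, 'poselegssidemovement': 10.0}
+ #   loss, loss_dict = loss_module(output_reproj=output_reproj, target_dict=target_dict,
+ #                                 weight_dict=weight_dict)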
+
+
diff --git a/src/combined_model/model_shape_v7.py b/src/combined_model/model_shape_v7.py
new file mode 100644
index 0000000000000000000000000000000000000000..807488d335e9a4f0870cff88a0540cc90b998f3f
--- /dev/null
+++ b/src/combined_model/model_shape_v7.py
@@ -0,0 +1,500 @@
+
+import pickle as pkl
+import numpy as np
+import torchvision.models as models
+from torchvision import transforms
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from kornia.geometry.subpix import dsnt # kornia 0.4.0
+
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+from stacked_hourglass.utils.evaluation import get_preds_soft
+from stacked_hourglass import hg1, hg2, hg8
+from lifting_to_3d.linear_model import LinearModelComplete, LinearModel
+from lifting_to_3d.inn_model_for_shape import INNForShape
+from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
+from smal_pytorch.smal_model.smal_torch_new import SMAL
+from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
+from bps_2d.bps_for_segmentation import SegBPS
+from configs.SMAL_configs import UNITY_SMAL_SHAPE_PRIOR_DOGS as SHAPE_PRIOR
+from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL
+
+
+
+class SmallLinear(nn.Module):
+ def __init__(self, input_size=64, output_size=30, linear_size=128):
+ super(SmallLinear, self).__init__()
+ self.relu = nn.ReLU(inplace=True)
+ self.w1 = nn.Linear(input_size, linear_size)
+ self.w2 = nn.Linear(linear_size, linear_size)
+ self.w3 = nn.Linear(linear_size, output_size)
+ def forward(self, x):
+ # pre-processing
+ y = self.w1(x)
+ y = self.relu(y)
+ y = self.w2(y)
+ y = self.relu(y)
+ y = self.w3(y)
+ return y
+
+
+class MyConv1d(nn.Module):
+ def __init__(self, input_size=37, output_size=30, start=True):
+ super(MyConv1d, self).__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.start = start
+ self.weight = Parameter(torch.ones((self.output_size)))
+ self.bias = Parameter(torch.zeros((self.output_size)))
+ def forward(self, x):
+ # pre-processing
+ if self.start:
+ y = x[:, :self.output_size]
+ else:
+ y = x[:, -self.output_size:]
+ y = y * self.weight[None, :] + self.bias[None, :]
+ return y
+
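+ # Note on MyConv1d (a clarifying sketch of its behaviour; no convolution is involved despite
+ # the name): it selects either the first (start=True) or the last (start=False) output_size
+ # entries of the input vector and applies a learnable elementwise affine map
+ #   y = x_selected * weight + bias.
+ # In the '1dconv' branch of ModelShapeAndBreed below it is used with
+ # n_z == n_betas + n_betas_limbs, so MyConv1d(n_z, n_betas, start=True) handles the betas and
+ # MyConv1d(n_z, n_betas_limbs, start=False) the limb-length coefficients.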
+
+class ModelShapeAndBreed(nn.Module):
+ def __init__(self, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'):
+ super(ModelShapeAndBreed, self).__init__()
+ self.n_betas = n_betas
+ self.n_betas_limbs = n_betas_limbs # n_betas_logscale
+ self.n_breeds = n_breeds
+ self.structure_z_to_betas = structure_z_to_betas
+ if self.structure_z_to_betas == '1dconv':
+ if not (n_z == self.n_betas+self.n_betas_limbs):
+ raise ValueError
+ # shape branch
+ self.resnet = models.resnet34(pretrained=False)
+ # replace the first layer
+ n_in = 3 + 1
+ self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+ # replace the last layer
+ self.resnet.fc = nn.Linear(512, n_z)
+ # softmax
+ self.soft_max = torch.nn.Softmax(dim=1)
+ # fc network (and other versions) to connect z with betas
+ p_dropout = 0.2
+ if self.structure_z_to_betas == 'default':
+ self.linear_betas = LinearModel(linear_size=1024,
+ num_stage=1,
+ p_dropout=p_dropout,
+ input_size=n_z,
+ output_size=self.n_betas)
+ self.linear_betas_limbs = LinearModel(linear_size=1024,
+ num_stage=1,
+ p_dropout=p_dropout,
+ input_size=n_z,
+ output_size=self.n_betas_limbs)
+ elif self.structure_z_to_betas == 'lin':
+ self.linear_betas = nn.Linear(n_z, self.n_betas)
+ self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs)
+ elif self.structure_z_to_betas == 'fc_0':
+ self.linear_betas = SmallLinear(linear_size=128, # 1024,
+ input_size=n_z,
+ output_size=self.n_betas)
+ self.linear_betas_limbs = SmallLinear(linear_size=128, # 1024,
+ input_size=n_z,
+ output_size=self.n_betas_limbs)
+ elif structure_z_to_betas == 'fc_1':
+ self.linear_betas = LinearModel(linear_size=64, # 1024,
+ num_stage=1,
+ p_dropout=0,
+ input_size=n_z,
+ output_size=self.n_betas)
+ self.linear_betas_limbs = LinearModel(linear_size=64, # 1024,
+ num_stage=1,
+ p_dropout=0,
+ input_size=n_z,
+ output_size=self.n_betas_limbs)
+ elif self.structure_z_to_betas == '1dconv':
+ self.linear_betas = MyConv1d(n_z, self.n_betas, start=True)
+ self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False)
+ elif self.structure_z_to_betas == 'inn':
+ self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0)
+ else:
+ raise ValueError
+ # network to connect latent shape vector z with dog breed classification
+ self.linear_breeds = LinearModel(linear_size=1024, # 1024,
+ num_stage=1,
+ p_dropout=p_dropout,
+ input_size=n_z,
+ output_size=self.n_breeds)
+ # shape multiplicator
+ self.shape_multiplicator_np = np.ones(self.n_betas)
+ with open(SHAPE_PRIOR, 'rb') as file:
+ u = pkl._Unpickler(file)
+ u.encoding = 'latin1'
+ res = u.load()
+ # shape predictions are centered around the mean dog of our dog model
+ self.betas_mean_np = res['dog_cluster_mean']
+
+ def forward(self, img, seg_raw=None, seg_prep=None):
+ # img is the network input image
+ # seg_raw is before softmax and subtracting 0.5
+ # seg_prep would be the prepared_segmentation
+ if seg_prep is None:
+ seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5
+ input_img_and_seg = torch.cat((img, seg_prep), axis=1)
+ res_output = self.resnet(input_img_and_seg)
+ dog_breed_output = self.linear_breeds(res_output)
+ if self.structure_z_to_betas == 'inn':
+ shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output)
+ else:
+ shape_output_orig = self.linear_betas(res_output) * 0.1
+ betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device)
+ shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas]
+ shape_limbs_output_orig = self.linear_betas_limbs(res_output)
+ shape_limbs_output = shape_limbs_output_orig * 0.1
+ output_dict = {'z': res_output,
+ 'breeds': dog_breed_output,
+ 'betas': shape_output_orig,
+ 'betas_limbs': shape_limbs_output_orig}
+ return output_dict
+
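+ # Note: the ResNet34 backbone above runs on 4 input channels, the RGB image concatenated
+ # with the prepared segmentation channel softmax(seg_raw)[:, 1:2] - 0.5 (foreground
+ # probability centered around zero); this is why conv1 is replaced in __init__.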
+
+
+class LearnableShapedirs(nn.Module):
+ def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10):
+ super(LearnableShapedirs, self).__init__()
+ # shapedirs_init = self.smal.shapedirs.detach()
+ self.n_betas = n_betas
+ self.n_betas_fixed = n_betas_fixed
+ self.sym_ids_dict = sym_ids_dict
+ sym_left_ids = self.sym_ids_dict['left']
+ sym_right_ids = self.sym_ids_dict['right']
+ sym_center_ids = self.sym_ids_dict['center']
+ self.n_center = sym_center_ids.shape[0]
+ self.n_left = sym_left_ids.shape[0]
+ self.n_sd = self.n_betas - self.n_betas_fixed # number of learnable shapedirs
+ # get indices to go from half_shapedirs to shapedirs
+ inds_back = np.zeros((3889))
+ for ind in range(0, sym_center_ids.shape[0]):
+ ind_in_forward = sym_center_ids[ind]
+ inds_back[ind_in_forward] = ind
+ for ind in range(0, sym_left_ids.shape[0]):
+ ind_in_forward = sym_left_ids[ind]
+ inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind
+ for ind in range(0, sym_right_ids.shape[0]):
+ ind_in_forward = sym_right_ids[ind]
+ inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind
+ self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long())
+ # self.smal.shapedirs: (51, 11667)
+ # shapedirs: (3889, 3, n_sd)
+ # shapedirs_half: (2012, 3, n_sd)
+ sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas))
+ self.register_buffer('sd', sd)
+ sd_center = sd[sym_center_ids, :, self.n_betas_fixed:]
+ sd_left = sd[sym_left_ids, :, self.n_betas_fixed:]
+ self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach()))
+ self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach()))
+ self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach()))
+ self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach()))
+ self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach()))
+ def forward(self):
+ device = self.learnable_half_shapedirs_c0.device
+ half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0, \
+ torch.zeros((self.n_center, self.n_sd)).to(device), \
+ self.learnable_half_shapedirs_c2), axis=1)
+ half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0, \
+ self.learnable_half_shapedirs_l1, \
+ self.learnable_half_shapedirs_l2), axis=1)
+ half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0, \
+ - self.learnable_half_shapedirs_l1, \
+ self.learnable_half_shapedirs_l2), axis=1)
+ half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right))
+ shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch)
+ shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2) # (3889, 3, n_sd)
+ shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0)) # (n_sd, 11667)
+ return shapedirs_complete, shapedirs_complete_prepared
+
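+ # Summary of the symmetry construction in LearnableShapedirs.forward: only the center
+ # vertices (with their lateral component fixed to zero) and the left-side vertices carry
+ # learnable parameters; the right-side entries are mirrored copies of the left ones with a
+ # negated second component, and inds_back_torch scatters the concatenated
+ # [center, left, right] half shapedirs back into the original vertex ordering, so any learned
+ # shape direction stays left/right symmetric.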
+
+
+
+
+class ModelImageToBreed(nn.Module):
+ def __init__(self, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True):
+ super(ModelImageToBreed, self).__init__()
+ self.n_classes = n_classes
+ self.n_partseg = n_partseg
+ self.n_betas = n_betas
+ self.n_betas_limbs = n_betas_limbs
+ self.n_keyp = n_keyp
+ self.n_bones = n_bones
+ self.n_breeds = n_breeds
+ self.image_size = image_size
+ self.upsample_seg = True
+ self.threshold_scores = thr_keyp_sc
+ self.n_z = n_z
+ self.add_partseg = add_partseg
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
+ if arch == 'hg8':
+ self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
+ else:
+ raise Exception('unrecognised model architecture: ' + arch)
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
+ self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z)
+ def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
+ batch_size = input_img.shape[0]
+ device = input_img.device
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
+ hourglass_out_dict = self.stacked_hourglass(input_img)
+ last_seg = hourglass_out_dict['seg_final']
+ last_heatmap = hourglass_out_dict['out_list_kp'][-1]
+ # - prepare keypoints (from heatmap)
+ # normalize predictions -> from logits to probability distribution
+ # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
+ # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
+ # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
+ keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
+ if self.threshold_scores is not None:
+ scores[scores>self.threshold_scores] = 1.0
+ scores[scores<=self.threshold_scores] = 0.0
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
+ # breed_model takes as input the image as well as the predicted segmentation map
+ # -> we need to split up ModelImageTo3d, such that we can use the silhouette
+ resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
+ pred_breed = resnet_output['breeds'] # (bs, n_breeds)
+ pred_betas = resnet_output['betas']
+ pred_betas_limbs = resnet_output['betas_limbs']
+ small_output = {'keypoints_norm': keypoints_norm,
+ 'keypoints_scores': scores}
+ small_output_reproj = {'betas': pred_betas,
+ 'betas_limbs': pred_betas_limbs,
+ 'dog_breed': pred_breed}
+ return small_output, None, small_output_reproj
+
+class ModelImageTo3d_withshape_withproj(nn.Module):
+ def __init__(self, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, fix_flength=False, render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None):
+ super(ModelImageTo3d_withshape_withproj, self).__init__()
+ self.n_classes = n_classes
+ self.n_partseg = n_partseg
+ self.n_betas = n_betas
+ self.n_betas_limbs = n_betas_limbs
+ self.n_keyp = n_keyp
+ self.n_bones = n_bones
+ self.n_breeds = n_breeds
+ self.image_size = image_size
+ self.threshold_scores = thr_keyp_sc
+ self.upsample_seg = True
+ self.silh_no_tail = silh_no_tail
+ self.add_z_to_3d_input = add_z_to_3d_input
+ self.add_segbps_to_3d_input = add_segbps_to_3d_input
+ self.add_partseg = add_partseg
+ assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input)
+ self.n_z = n_z
+ if add_segbps_to_3d_input:
+ self.n_segbps = n_segbps # 64
+ self.segbps_model = SegBPS()
+ else:
+ self.n_segbps = 0
+ self.fix_flength = fix_flength
+ self.render_partseg = render_partseg
+ self.structure_z_to_betas = structure_z_to_betas
+ self.structure_pose_net = structure_pose_net
+ assert self.structure_pose_net in ['default', 'vae', 'normflow']
+ self.nf_version = nf_version
+ self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas)))
+ self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32))
+ p_dropout = 0.2 # 0.5
+ # ------------------------------ SMAL MODEL ------------------------------
+ self.smal = SMAL(template_name='neutral')
+ # New for rendering without tail
+ f_np = self.smal.faces.detach().cpu().numpy()
+ self.f_no_tail_np = f_np[np.isin(f_np[:,:], VERTEX_IDS_TAIL).sum(axis=1)==0, :]
+ # in theory we could optimize for improved shapedirs, but we do not do that
+ # -> would need to implement regularizations
+ # -> there are better ways than changing the shapedirs
+ self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10)
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
+ if arch == 'hg8':
+ self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
+ else:
+ raise Exception('unrecognised model architecture: ' + arch)
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
+ self.breed_model = ModelShapeAndBreed(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas)
+ # ------------------------------ LINEAR 3D MODEL ------------------------------
+ # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
+ self.soft_max = torch.nn.Softmax(dim=1)
+ input_size = self.n_keyp*3 + self.n_bones
+ self.model_3d = LinearModelComplete(linear_size=1024,
+ num_stage_comb=num_stage_comb,
+ num_stage_heads=num_stage_heads,
+ num_stage_heads_pose=num_stage_heads_pose,
+ trans_sep=trans_sep,
+ p_dropout=p_dropout, # 0.5,
+ input_size=input_size,
+ intermediate_size=1024,
+ output_info=None,
+ n_joints=n_joints,
+ n_z=self.n_z,
+ add_z_to_3d_input=self.add_z_to_3d_input,
+ n_segbps=self.n_segbps,
+ add_segbps_to_3d_input=self.add_segbps_to_3d_input,
+ structure_pose_net=self.structure_pose_net,
+ nf_version = self.nf_version)
+ # ------------------------------ RENDERING ------------------------------
+ self.silh_renderer = SilhRenderer(image_size)
+
+ def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
+ batch_size = input_img.shape[0]
+ device = input_img.device
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
+ hourglass_out_dict = self.stacked_hourglass(input_img)
+ last_seg = hourglass_out_dict['seg_final']
+ last_heatmap = hourglass_out_dict['out_list_kp'][-1]
+ # - prepare keypoints (from heatmap)
+ # normalize predictions -> from logits to probability distribution
+ # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
+ # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
+ # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
+ keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
+ if self.threshold_scores is not None:
+ scores[scores>self.threshold_scores] = 1.0
+ scores[scores<=self.threshold_scores] = 0.0
+ # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------
+ # in our cvpr 2022 paper we do not change the shapedirs
+ # learnable_sd_complete has shape (3889, 3, n_sd)
+ # learnable_sd_complete_prepared has shape (n_sd, 11667)
+ learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs()
+ shapedirs_sel = learnable_sd_complete_prepared # None
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
+ # breed_model takes as input the image as well as the predicted segmentation map
+ # -> we need to split up ModelImageTo3d, such that we can use the silhouette
+ resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
+ pred_breed = resnet_output['breeds'] # (bs, n_breeds)
+ pred_z = resnet_output['z']
+ # - prepare shape
+ pred_betas = resnet_output['betas']
+ pred_betas_limbs = resnet_output['betas_limbs']
+ # - calculate bone lengths
+ with torch.no_grad():
+ use_mean_bone_lengths = False
+ if use_mean_bone_lengths:
+ bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
+ else:
+ assert (bone_lengths_prepared is None)
+ bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True)
+ # ------------------------------ LINEAR 3D MODEL ------------------------------
+ # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
+ # prepare input for 2d-to-3d network
+ keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2)
+ if bone_lengths_prepared is None:
+ bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
+ # should we add silhouette to 3d input? should we add z?
+ if self.add_segbps_to_3d_input:
+ seg_raw = last_seg
+ seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :] # class 1 is the dog
+ with torch.no_grad():
+ seg_prep_np = seg_prep_bps.detach().cpu().numpy()
+ bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np) # (bs, 64, 2)
+ bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1))
+ bps_output_prep = bps_output * 2. - 1
+ input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
+ input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1)
+ elif self.add_z_to_3d_input:
+ # we do not use this in our cvpr 2022 version
+ input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
+ input_vec_additional = pred_z
+ input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1)
+ else:
+ input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
+ # predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
+ output = self.model_3d(input_vec)
+ # add predicted keypoints to the output dict
+ output['keypoints_norm'] = keypoints_norm
+ output['keypoints_scores'] = scores
+ # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again
+ pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
+ if self.structure_pose_net == 'default':
+ pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :]
+ elif self.structure_pose_net == 'normflow':
+ pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
+ pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
+ pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
+ else:
+ pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
+ pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
+ pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
+ pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6)))
+ pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3))
+ pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6))
+
+ if self.fix_flength:
+ output['flength'] = torch.zeros_like(output['flength'])
+ pred_flength = torch.ones_like(output['flength'])*2100 # norm_dict['flength_mean'][None, :]
+ else:
+ pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
+ pred_flength = pred_flength_orig.clone() # torch.abs(pred_flength_orig)
+ pred_flength[pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
+
+ # ------------------------------ RENDERING ------------------------------
+ # get 3d model (SMAL)
+ V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf='green', shapedirs_sel=shapedirs_sel)
+ keyp_3d = keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
+ # render silhouette
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
+ if not self.silh_no_tail:
+ pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
+ points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength)
+ else:
+ faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
+ pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
+ points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength)
+ # get torch 'Meshes'
+ torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep)
+
+ # render body parts (not part of cvpr 2022 version)
+ if self.render_partseg:
+ raise NotImplementedError
+ else:
+ partseg_images = None
+ partseg_images_hg = None
+
+ # ------------------------------ PREPARE OUTPUT ------------------------------
+ # create output dictionaries
+ # output: contains all output from model_image_to_3d
+ # output_unnorm: same as output, but normalizations are undone
+ # output_reproj: smal output and reprojected keypoints as well as silhouette
+ keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1)
+ output_unnorm = {'pose_rotmat': pred_pose,
+ 'flength': pred_flength,
+ 'trans': pred_trans,
+ 'keypoints':keypoints_heatmap_256}
+ output_reproj = {'vertices_smal': V,
+ 'torch_meshes': torch_meshes,
+ 'keyp_3d': keyp_3d,
+ 'keyp_2d': pred_keyp,
+ 'silh': pred_silh_images,
+ 'betas': pred_betas,
+ 'betas_limbs': pred_betas_limbs,
+ 'pose_rot6d': pred_pose_rot6d, # used for pose prior...
+ 'dog_breed': pred_breed,
+ 'shapedirs': shapedirs_sel,
+ 'z': pred_z,
+ 'flength_unnorm': pred_flength,
+ 'flength': output['flength'],
+ 'partseg_images_rend': partseg_images,
+ 'partseg_images_hg_nograd': partseg_images_hg,
+ 'normflow_z': output['normflow_z']}
+
+ return output, output_unnorm, output_reproj
+
+ def render_vis_nograd(self, vertices, focal_lengths, color=0):
+ # this function is for visualization only
+ # vertices: (bs, n_verts, 3)
+ # focal_lengths: (bs, 1)
+ # color: integer, either 0 or 1
+ # returns a torch tensor of shape (bs, image_size, image_size, 3)
+ with torch.no_grad():
+ batch_size = vertices.shape[0]
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
+ visualizations = self.silh_renderer.get_visualization_nograd(vertices,
+ faces_prep, focal_lengths, color=color)
+ return visualizations
+
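+ # Minimal usage sketch for the full model (hypothetical arguments; the normalization
+ # statistics normally come from data_info, see train_main_image_to_3d_withbreedrel.py):
+ #   model = ModelImageTo3d_withshape_withproj(arch='hg8', n_betas=10, nf_version=3)
+ #   norm_dict = {'pose_rot6d_mean': ..., 'trans_mean': ..., 'trans_std': ...,
+ #                'flength_mean': ..., 'flength_std': ...}   # float tensors on the model device
+ #   output, output_unnorm, output_reproj = model(input_img, norm_dict=norm_dict)
+ #   # input_img: (bs, 3, 256, 256) normalized RGB crop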
diff --git a/src/combined_model/train_main_image_to_3d_withbreedrel.py b/src/combined_model/train_main_image_to_3d_withbreedrel.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c06655d08cbc60e1239147aa01272cd901fa04b
--- /dev/null
+++ b/src/combined_model/train_main_image_to_3d_withbreedrel.py
@@ -0,0 +1,470 @@
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn
+import torch.nn.parallel
+from tqdm import tqdm
+import os
+import pathlib
+from matplotlib import pyplot as plt
+import cv2
+import numpy as np
+import torch
+import trimesh
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
+from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
+from metrics.metrics import Metrics
+from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS
+
+
+# ---------------------------------------------------------------------------------------------------------------------------
+def do_training_epoch(train_loader, model, loss_module, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None):
+ losses = AverageMeter()
+ losses_keyp = AverageMeter()
+ losses_silh = AverageMeter()
+ losses_shape = AverageMeter()
+ losses_pose = AverageMeter()
+ losses_class = AverageMeter()
+ losses_breed = AverageMeter()
+ losses_partseg = AverageMeter()
+ accuracies = AverageMeter()
+ # Put the model in training mode.
+ model.train()
+ # prepare progress bar
+ iterable = enumerate(train_loader)
+ progress = None
+ if not quiet:
+ progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
+ iterable = progress
+ # information for normalization
+ norm_dict = {
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
+ # prepare variables, put them on the right device
+ for i, (input, target_dict) in iterable:
+ batch_size = input.shape[0]
+ for key in target_dict.keys():
+ if key == 'breed_index':
+ target_dict[key] = target_dict[key].long().to(device)
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
+ target_dict[key] = target_dict[key].float().to(device)
+ elif key == 'has_seg':
+ target_dict[key] = target_dict[key].to(device)
+ else:
+ pass
+ input = input.float().to(device)
+
+ # ----------------------- do training step -----------------------
+ assert model.training, 'model must be in training mode.'
+ with torch.enable_grad():
+ # ----- forward pass -----
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
+ # ----- loss -----
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
+ target_dict=target_dict,
+ weight_dict=weight_dict)
+ # ----- backward pass and parameter update -----
+ optimiser.zero_grad()
+ loss.backward()
+ optimiser.step()
+ # ----------------------------------------------------------------
+
+ # prepare losses for progress bar
+ bs_fake = 1 # batch_size
+ losses.update(loss_dict['loss'], bs_fake)
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
+ accuracies.update(acc, bs_fake)
+ # Show losses as part of the progress bar.
+ if progress is not None:
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
+ loss=losses.avg,
+ loss_keyp=losses_keyp.avg,
+ loss_silh=losses_silh.avg,
+ loss_shape=losses_shape.avg,
+ loss_pose=losses_pose.avg,
+ loss_class=losses_class.avg,
+ loss_breed=losses_breed.avg,
+ loss_partseg=losses_partseg.avg
+ )
+ progress.set_postfix_str(my_string)
+
+ return my_string, accuracies.avg
+
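+ # Minimal usage sketch (hypothetical optimiser settings; loaders, model, loss_module and
+ # weight_dict are created by the calling training script):
+ #   optimiser = torch.optim.Adam(model.parameters(), lr=5e-5)
+ #   for epoch in range(num_epochs):
+ #       train_string, train_acc = do_training_epoch(train_loader, model, loss_module, device,
+ #                                                   data_info, optimiser, weight_dict=weight_dict)
+ #       val_string, val_acc = do_validation_epoch(val_loader, model, loss_module, device,
+ #                                                 data_info, weight_dict=weight_dict)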
+
+# ---------------------------------------------------------------------------------------------------------------------------
+def do_validation_epoch(val_loader, model, loss_module, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None):
+ losses = AverageMeter()
+ losses_keyp = AverageMeter()
+ losses_silh = AverageMeter()
+ losses_shape = AverageMeter()
+ losses_pose = AverageMeter()
+ losses_class = AverageMeter()
+ losses_breed = AverageMeter()
+ losses_partseg = AverageMeter()
+ accuracies = AverageMeter()
+ if save_imgs_path is not None:
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
+ # Put the model in evaluation mode.
+ model.eval()
+ # prepare progress bar
+ iterable = enumerate(val_loader)
+ progress = None
+ if not quiet:
+ progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
+ iterable = progress
+ # summarize information for normalization
+ norm_dict = {
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
+ batch_size = val_loader.batch_size
+ # prepare variables, put them on the right device
+ my_step = 0
+ for i, (input, target_dict) in iterable:
+ curr_batch_size = input.shape[0]
+ for key in target_dict.keys():
+ if key == 'breed_index':
+ target_dict[key] = target_dict[key].long().to(device)
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
+ target_dict[key] = target_dict[key].float().to(device)
+ elif key == 'has_seg':
+ target_dict[key] = target_dict[key].to(device)
+ else:
+ pass
+ input = input.float().to(device)
+
+ # ----------------------- do validation step -----------------------
+ with torch.no_grad():
+ # ----- forward pass -----
+ # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores'])
+ # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints'])
+ # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength'])
+ # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh'])
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
+ # ----- loss -----
+ if metrics == 'no_loss':
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
+ target_dict=target_dict,
+ weight_dict=weight_dict)
+ # ----------------------------------------------------------------
+
+ if i == 0:
+ if len_dataset is None:
+ len_data = val_loader.batch_size * len(val_loader) # 1703
+ else:
+ len_data = len_dataset
+ if metrics == 'all' or metrics == 'no_loss':
+ pck = np.zeros((len_data))
+ pck_by_part = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
+ acc_sil_2d = np.zeros(len_data)
+
+ all_betas = np.zeros((len_data, output_reproj['betas'].shape[1]))
+ all_betas_limbs = np.zeros((len_data, output_reproj['betas_limbs'].shape[1]))
+ all_z = np.zeros((len_data, output_reproj['z'].shape[1]))
+ all_pose_rotmat = np.zeros((len_data, output_unnorm['pose_rotmat'].shape[1], 3, 3))
+ all_flength = np.zeros((len_data, output_unnorm['flength'].shape[1]))
+ all_trans = np.zeros((len_data, output_unnorm['trans'].shape[1]))
+ all_breed_indices = np.zeros((len_data))
+ all_image_names = [] # len_data * [None]
+
+ index = i
+ ind_img = 0
+ if save_imgs_path is not None:
+ # render predicted 3d models
+ visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
+ focal_lengths=output_unnorm['flength'],
+ color=0) # color=2)
+ for ind_img in range(len(target_dict['index'])):
+ try:
+ if test_name_list is not None:
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
+ img_name = img_name.split('.')[0]
+ else:
+ img_name = str(index) + '_' + str(ind_img)
+ # save image with predicted keypoints
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
+ pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
+ pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
+ inp_img = input[ind_img, :, :, :].detach().clone()
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
+ # save predicted 3d model (front view)
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
+ pred_tex_max = np.max(pred_tex, axis=2)
+ out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
+ plt.imsave(out_path, pred_tex)
+ input_image = input[ind_img, :, :, :].detach().clone()
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
+ out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
+ plt.imsave(out_path, im_masked)
+ # save predicted 3d model (side view)
+ vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
+ tensor_0 = torch.zeros(1).float().to(device)
+ tensor_1 = torch.ones(1).float().to(device)
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
+ RY = torch.stack([
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
+ torch.stack([tensor_0, tensor_1, tensor_0]),
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
+
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
+ focal_lengths=output_unnorm['flength'],
+ color=0) # 2)
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
+ pred_tex_max = np.max(pred_tex, axis=2)
+ out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
+ plt.imsave(out_path, pred_tex)
+ if render_all:
+ # save input image
+ inp_img = input[ind_img, :, :, :].detach().clone()
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
+ save_input_image(inp_img, out_path)
+ # save mesh
+ V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
+ Faces = model.smal.f
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False)
+ mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
+ except:
+ print('could not save the visualizations for this image')
+
+ if metrics == 'all' or metrics == 'no_loss':
+ # prepare a dictionary with all the predicted results
+ preds = {}
+ preds['betas'] = output_reproj['betas'].cpu().detach().numpy()
+ preds['betas_limbs'] = output_reproj['betas_limbs'].cpu().detach().numpy()
+ preds['z'] = output_reproj['z'].cpu().detach().numpy()
+ preds['pose_rotmat'] = output_unnorm['pose_rotmat'].cpu().detach().numpy()
+ preds['flength'] = output_unnorm['flength'].cpu().detach().numpy()
+ preds['trans'] = output_unnorm['trans'].cpu().detach().numpy()
+ preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
+ img_names = []
+ for ind_img2 in range(0, output_reproj['betas'].shape[0]):
+ if test_name_list is not None:
+ img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
+ img_name2 = img_name2.split('.')[0]
+ else:
+ img_name2 = str(index) + '_' + str(ind_img2)
+ img_names.append(img_name2)
+ preds['image_names'] = img_names
+ # prepare keypoints for PCK calculation - predicted as well as ground truth
+ pred_keypoints_norm = output['keypoints_norm'] # -1 to 1
+ pred_keypoints_256 = output_reproj['keyp_2d']
+ pred_keypoints = pred_keypoints_256
+ gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
+ gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1
+ gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm
+ # prepare silhouette for IoU calculation - predicted as well as ground truth
+ has_seg = target_dict['has_seg']
+ img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
+ gtseg = target_dict['silh']
+ synth_silhouettes = output_reproj['silh'][:, 0, :, :] # output_reproj['silh']
+ synth_silhouettes[synth_silhouettes>0.5] = 1
+ synth_silhouettes[synth_silhouettes<0.5] = 0
+ # calculate PCK as well as IoU (similar to WLDO)
+ preds['acc_PCK'] = Metrics.PCK(
+ pred_keypoints, gt_keypoints,
+ gtseg, has_seg, idxs=EVAL_KEYPOINTS,
+ thresh_range=[pck_thresh], # [0.15],
+ )
+ preds['acc_IOU'] = Metrics.IOU(
+ synth_silhouettes, gtseg,
+ img_border_mask, mask=has_seg
+ )
+ for group, group_kps in KEYPOINT_GROUPS.items():
+ preds[f'{group}_PCK'] = Metrics.PCK(
+ pred_keypoints, gt_keypoints, gtseg, has_seg,
+ thresh_range=[pck_thresh], # [0.15],
+ idxs=group_kps
+ )
+ # add results for all images in this batch to lists
+ curr_batch_size = pred_keypoints_256.shape[0]
+ if not (preds['acc_PCK'].data.cpu().numpy().shape == (pck[my_step * batch_size:my_step * batch_size + curr_batch_size]).shape):
+ import pdb; pdb.set_trace()
+ pck[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
+ acc_sil_2d[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
+ for part in pck_by_part:
+ pck_by_part[part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
+ all_betas[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas']
+ all_betas_limbs[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs']
+ all_z[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z']
+ all_pose_rotmat[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat']
+ all_flength[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength']
+ all_trans[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans']
+ all_breed_indices[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index']
+ all_image_names.extend(preds['image_names'])
+ # update progress bar
+ if progress is not None:
+ my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format(
+ pck[:(my_step * batch_size + curr_batch_size)].mean(),
+ acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())
+ progress.set_postfix_str(my_string)
+ else:
+ # measure accuracy and record loss
+ bs_fake = 1 # batch_size
+ losses.update(loss_dict['loss'], bs_fake)
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
+ accuracies.update(acc, bs_fake)
+ # Show losses as part of the progress bar.
+ if progress is not None:
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
+ loss=losses.avg,
+ loss_keyp=losses_keyp.avg,
+ loss_silh=losses_silh.avg,
+ loss_shape=losses_shape.avg,
+ loss_pose=losses_pose.avg,
+ loss_class=losses_class.avg,
+ loss_breed=losses_breed.avg,
+ loss_partseg=losses_partseg.avg
+ )
+ progress.set_postfix_str(my_string)
+ my_step += 1
+ if metrics == 'all':
+ summary = {'pck': pck, 'acc_sil_2d': acc_sil_2d, 'pck_by_part':pck_by_part,
+ 'betas': all_betas, 'betas_limbs': all_betas_limbs, 'z': all_z, 'pose_rotmat': all_pose_rotmat,
+ 'flenght': all_flength, 'trans': all_trans, 'image_names': all_image_names, 'breed_indices': all_breed_indices}
+ return my_string, summary
+ elif metrics == 'no_loss':
+ return my_string, np.average(np.asarray(acc_sil_2d))
+ else:
+ return my_string, accuracies.avg
+
+
+# ---------------------------------------------------------------------------------------------------------------------------
+def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False):
+ if save_imgs_path is not None:
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
+ all_results = []
+
+ # Put the model in evaluation mode.
+ model.eval()
+
+ iterable = enumerate(val_loader)
+
+ # information for normalization
+ norm_dict = {
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
+
+ for i, (input, target_dict) in iterable:
+ batch_size = input.shape[0]
+ input = input.float().to(device)
+ partial_results = {}
+
+ # ----------------------- do visualization step -----------------------
+ with torch.no_grad():
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
+
+ index = i
+ ind_img = 0
+ for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size):
+
+ try:
+ if test_name_list is not None:
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
+ img_name = img_name.split('.')[0]
+ else:
+ img_name = str(index) + '_' + str(ind_img)
+ partial_results['img_name'] = img_name
+ visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
+ focal_lengths=output_unnorm['flength'],
+ color=0) # 2)
+ # save image with predicted keypoints
+ pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
+ pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
+ inp_img = input[ind_img, :, :, :].detach().clone()
+ if save_imgs_path is not None:
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
+ # save predicted 3d model
+ # (1) front view
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
+ pred_tex_max = np.max(pred_tex, axis=2)
+ partial_results['tex_pred'] = pred_tex
+ if save_imgs_path is not None:
+ out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
+ plt.imsave(out_path, pred_tex)
+ input_image = input[ind_img, :, :, :].detach().clone()
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
+ partial_results['comp_pred'] = im_masked
+ if save_imgs_path is not None:
+ out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
+ plt.imsave(out_path, im_masked)
+ # (2) side view
+ vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
+ tensor_0 = torch.zeros(1).float().to(device)
+ tensor_1 = torch.ones(1).float().to(device)
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
+ RY = torch.stack([
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
+ torch.stack([tensor_0, tensor_1, tensor_0]),
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3))
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
+ focal_lengths=output_unnorm['flength'],
+ color=0) # 2)
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
+ pred_tex_max = np.max(pred_tex, axis=2)
+ partial_results['rot_tex_pred'] = pred_tex
+ if save_imgs_path is not None:
+ out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
+ plt.imsave(out_path, pred_tex)
+ render_all = True
+ if render_all:
+ # save input image
+ inp_img = input[ind_img, :, :, :].detach().clone()
+ if save_imgs_path is not None:
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
+ save_input_image(inp_img, out_path)
+ # save posed mesh
+ V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
+ Faces = model.smal.f
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False)
+ partial_results['mesh_posed'] = mesh_posed
+ if save_imgs_path is not None:
+ mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
+ except:
+ print('could not process this image, skipping it')
+ all_results.append(partial_results)
+ if return_results:
+ return all_results
+ else:
+ return
\ No newline at end of file
diff --git a/src/configs/SMAL_configs.py b/src/configs/SMAL_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c977b887b2265ddc44bb85f4713be673bae64ff
--- /dev/null
+++ b/src/configs/SMAL_configs.py
@@ -0,0 +1,165 @@
+
+
+import numpy as np
+import os
+import sys
+
+
+# SMAL_DATA_DIR = '/is/cluster/work/nrueegg/dog_project/pytorch-dogs-inference/src/smal_pytorch/smpl_models/'
+# SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'smal_pytorch', 'smal_data')
+SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'smal_data')
+
+# we replace the old SMAL model by a more dog-specific model (see the BARC CVPR 2022 paper)
+# our model differs from the original SMAL model in several ways, for example:
+# - the PCA shape space is recalculated (from partially new data and weighted)
+# - coefficients for limb length changes are allowed (similar to WLDO, we did borrow some of their code)
+# - all dogs have a core of approximately the same length
+# - dogs are centered in their root joint (which is close to the tail base)
+#   -> this way the root rotation is always around this joint AND (0, 0, 0)
+#   -> previously the animal could 'slip' from the image center to the side when being rotated;
+#      now 'trans' also defines the center of the rotation
+# - we correct the back joint locations such that all those joints are more aligned
+SMAL_MODEL_PATH = os.path.join(SMAL_DATA_DIR, 'my_smpl_SMBLD_nbj_v3.pkl')
+UNITY_SMAL_SHAPE_PRIOR_DOGS = os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl')
+
+SYMMETRY_INDS_FILE = os.path.join(SMAL_DATA_DIR, 'symmetry_inds.json')
+
+mean_dog_bone_lengths_txt = os.path.join(SMAL_DATA_DIR, 'mean_dog_bone_lengths.txt')
+
+# there exist different keypoint configurations, for example keypoints corresponding to SMAL joints or keypoints defined based on vertex locations
+KEYPOINT_CONFIGURATION = 'green' # green: same as in https://github.com/benjiebob/SMALify/blob/master/config.py
+
+# some vertex indices (from Silvia Zuffi's code, create_projected_images_cats.py)
+KEY_VIDS = np.array(([1068, 1080, 1029, 1226], # left eye
+ [2660, 3030, 2675, 3038], # right eye
+ [910], # mouth low
+ [360, 1203, 1235, 1230], # front left leg, low
+ [3188, 3156, 2327, 3183], # front right leg, low
+ [1976, 1974, 1980, 856], # back left leg, low
+ [3854, 2820, 3852, 3858], # back right leg, low
+ [452, 1811], # tail start
+ [416, 235, 182], # front left leg, top
+ [2156, 2382, 2203], # front right leg, top
+ [829], # back left leg, top
+ [2793], # back right leg, top
+ [60, 114, 186, 59], # throat, close to base of neck
+ [2091, 2037, 2036, 2160], # withers (a bit lower than in reality)
+ [384, 799, 1169, 431], # front left leg, middle
+ [2351, 2763, 2397, 3127], # front right leg, middle
+ [221, 104], # back left leg, middle
+ [2754, 2192], # back right leg, middle
+ [191, 1158, 3116, 2165], # neck
+ [28], # Tail tip
+ [542], # Left Ear
+ [2507], # Right Ear
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
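+
+# Usage sketch (an illustration added here, not part of the original pipeline): each of the
+# 24 keypoints above can be approximated as the mean of its vertex group, assuming `vertices`
+# is an (n_vertices, 3) numpy array of posed SMAL vertices.
+def _approx_keypoints_from_vertices(vertices):
+    # returns an array of shape (24, 3), one 3d location per entry of KEY_VIDS
+    return np.stack([vertices[np.asarray(vids, dtype=int), :].mean(axis=0) for vids in KEY_VIDS])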
+
+# the following vertices are used for visibility checks only: if one of the vertices is visible,
+# then we assume that the joint is visible! There is some noise, but we don't care, as this is
+# used only for the generation of the synthetic dataset
+KEY_VIDS_VISIBILITY_ONLY = np.array(([1068, 1080, 1029, 1226, 645], # left eye
+ [2660, 3030, 2675, 3038, 2567], # right eye
+ [910, 11, 5], # mouth low
+ [360, 1203, 1235, 1230, 298, 408, 303, 293, 384], # front left leg, low
+ [3188, 3156, 2327, 3183, 2261, 2271, 2573, 2265], # front right leg, low
+ [1976, 1974, 1980, 856, 559, 851, 556], # back left leg, low
+ [3854, 2820, 3852, 3858, 2524, 2522, 2815, 2072], # back right leg, low
+ [452, 1811, 63, 194, 52, 370, 64], # tail start
+ [416, 235, 182, 440, 8, 80, 73, 112], # front left leg, top
+ [2156, 2382, 2203, 2050, 2052, 2406, 3], # front right leg, top
+ [829, 219, 218, 173, 17, 7, 279], # back left leg, top
+ [2793, 582, 140, 87, 2188, 2147, 2063], # back right leg, top
+ [60, 114, 186, 59, 878, 130, 189, 45], # throat, close to base of neck
+ [2091, 2037, 2036, 2160, 190, 2164], # withers (a bit lower than in reality)
+ [384, 799, 1169, 431, 321, 314, 437, 310, 323], # front left leg, middle
+ [2351, 2763, 2397, 3127, 2278, 2285, 2282, 2275, 2359], # front right leg, middle
+ [221, 104, 105, 97, 103], # back left leg, middle
+ [2754, 2192, 2080, 2251, 2075, 2074], # back right leg, middle
+ [191, 1158, 3116, 2165, 154, 653, 133, 339], # neck
+ [28, 474, 475, 731, 24], # Tail tip
+ [542, 147, 509, 200, 522], # Left Ear
+ [2507,2174, 2122, 2126, 2474], # Right Ear
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
+
+# see: https://github.com/benjiebob/SMALify/blob/master/config.py
+# JOINT DEFINITIONS - based on SMAL joints and additional {eyes, ear tips, chin and nose}
+TORSO_JOINTS = [2, 5, 8, 11, 12, 23]
+CANONICAL_MODEL_JOINTS = [
+ 10, 9, 8, # upper_left [paw, middle, top]
+ 20, 19, 18, # lower_left [paw, middle, top]
+ 14, 13, 12, # upper_right [paw, middle, top]
+ 24, 23, 22, # lower_right [paw, middle, top]
+ 25, 31, # tail [start, end]
+ 33, 34, # ear base [left, right]
+ 35, 36, # nose, chin
+ 38, 37, # ear tip [left, right]
+ 39, 40, # eyes [left, right]
+ 6, 11, # withers, throat (throat is inaccurate and withers also)
+ 28] # tail middle
+ # old: 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat
+
+
+
+# the following list gives, for each entry of CANONICAL_MODEL_JOINTS, the index within
+# KEY_VIDS that is used to judge whether that joint is visible - those are all approximations!
+CMJ_VISIBILITY_IN_KEY_VIDS = [
+ 3, 14, 8, # left front leg
+ 5, 16, 10, # left rear leg
+ 4, 15, 9, # right front leg
+ 6, 17, 11, # right rear leg
+ 7, 19, # tail front, tail back
+ 20, 21, # ear base (but cannot be found in blue, so we take the tip)
+ 2, 2, # mouth (was: 22, 2)
+ 20, 21, # ear tips
+ 1, 0, # eyes
+ 18, # withers, not sure where this point is
+ 12, # throat
+ 23, # mid tail
+ ]
+
+# define which bone lengths are used as input to the 2d-to-3d network
+IDXS_BONES_NO_REDUNDANCY = [6,7,8,9,16,17,18,19,32,1,2,3,4,5,14,15,24,25,26,27,28,29,30,31]
+# load bone lengths of the mean dog (already filtered)
+mean_dog_bone_lengths = []
+with open(mean_dog_bone_lengths_txt, 'r') as f:
+ for line in f:
+ mean_dog_bone_lengths.append(float(line.split('\n')[0]))
+MEAN_DOG_BONE_LENGTHS_NO_RED = np.asarray(mean_dog_bone_lengths)[IDXS_BONES_NO_REDUNDANCY] # (24, )
+
+# Body part segmentation:
+# the body can be segmented based on the bones and for the new dog model also based on the new shapedirs
+# axis_horizontal = self.shapedirs[2, :].reshape((-1, 3))[:, 0]
+# all_indices = np.arange(3889)
+# tail_indices = all_indices[axis_horizontal.detach().cpu().numpy() < 0.0]
+VERTEX_IDS_TAIL = [ 0, 4, 9, 10, 24, 25, 28, 453, 454, 456, 457,
+ 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468,
+ 469, 470, 471, 472, 473, 474, 475, 724, 725, 726, 727,
+ 728, 729, 730, 731, 813, 975, 976, 977, 1109, 1110, 1111,
+ 1811, 1813, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827,
+ 1828, 1835, 1836, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967,
+ 1968, 1969, 2418, 2419, 2421, 2422, 2423, 2424, 2425, 2426, 2427,
+ 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438,
+ 2439, 2440, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2777,
+ 3067, 3068, 3069, 3842, 3843, 3844, 3845, 3846, 3847]
+
+# same as in https://github.com/benjiebob/WLDO/blob/master/global_utils/config.py
+EVAL_KEYPOINTS = [
+ 0, 1, 2, # left front
+ 3, 4, 5, # left rear
+ 6, 7, 8, # right front
+ 9, 10, 11, # right rear
+ 12, 13, # tail start -> end
+ 14, 15, # left ear, right ear
+ 16, 17, # nose, chin
+ 18, 19] # left tip, right tip
+
+KEYPOINT_GROUPS = {
+ 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs
+ 'tail': [12, 13], # tail
+ 'ears': [14, 15, 18, 19], # ears
+ 'face': [16, 17] # face
+}
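+
+# Usage sketch (an illustration, not used elsewhere in this file): per-group error averages
+# can be computed from a vector of per-keypoint errors ordered like EVAL_KEYPOINTS.
+def _mean_error_per_keypoint_group(per_keypoint_error):
+    # per_keypoint_error: array-like of length 20, one value per entry of EVAL_KEYPOINTS
+    return {name: float(np.mean([per_keypoint_error[i] for i in idxs]))
+            for name, idxs in KEYPOINT_GROUPS.items()}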
+
+
diff --git a/src/configs/anipose_data_info.py b/src/configs/anipose_data_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7bad68b45cf9926fdfd3ca1b7e1f147e909cfd
--- /dev/null
+++ b/src/configs/anipose_data_info.py
@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+from typing import List
+import json
+import numpy as np
+import os
+
+STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
+STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
+
+@dataclass
+class DataInfo:
+ rgb_mean: List[float]
+ rgb_stddev: List[float]
+ joint_names: List[str]
+ hflip_indices: List[int]
+ n_joints: int
+ n_keyp: int
+ n_bones: int
+ n_betas: int
+ image_size: int
+ trans_mean: np.ndarray
+ trans_std: np.ndarray
+ flength_mean: np.ndarray
+ flength_std: np.ndarray
+ pose_rot6d_mean: np.ndarray
+ keypoint_weights: List[float]
+
+# SMAL samples 3d statistics
+# statistics such as the mean values were computed once at the start of the project and have not been changed since
+def load_statistics(statistics_path):
+ with open(statistics_path) as f:
+ statistics = json.load(f)
+ '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
+ statistics['pose_mean'] = new_pose_mean
+ j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
+ with open(self.statistics_path, 'w') as file: file.write(j_out)'''
+ new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
+ 'trans_std': np.asarray(statistics['trans_std']),
+ 'flength_mean': np.asarray(statistics['flength_mean']),
+ 'flength_std': np.asarray(statistics['flength_std']),
+ 'pose_mean': np.asarray(statistics['pose_mean']),
+ }
+ new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
+ return new_statistics
+STATISTICS = load_statistics(STATISTICS_PATH)
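+# Added note: in load_statistics above, statistics['pose_mean'] holds one 3x3 mean rotation
+# matrix per joint; keeping only the first two columns of each matrix and flattening them
+# yields the 6D rotation representation of Zhou et al. (CVPR 2019), so pose_rot6d_mean has
+# shape (n_joints, 6) and matches rotmat_to_rot6d() in src/lifting_to_3d/utils/geometry_utils.py.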
+
+AniPose_JOINT_NAMES_swapped = [
+ 'L_F_Paw', 'L_F_Knee', 'L_F_Elbow',
+ 'L_B_Paw', 'L_B_Knee', 'L_B_Elbow',
+ 'R_F_Paw', 'R_F_Knee', 'R_F_Elbow',
+ 'R_B_Paw', 'R_B_Knee', 'R_B_Elbow',
+ 'TailBase', '_Tail_end_', 'L_EarBase', 'R_EarBase',
+ 'Nose', '_Chin_', '_Left_ear_tip_', '_Right_ear_tip_',
+ 'L_Eye', 'R_Eye', 'Withers', 'Throat']
+
+KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
+
+COMPLETE_DATA_INFO = DataInfo(
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
+ joint_names=AniPose_JOINT_NAMES_swapped, # AniPose_JOINT_NAMES,
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
+ n_joints = 35,
+ n_keyp = 24, # 20, # 25,
+ n_bones = 24,
+ n_betas = 30, # 10,
+ image_size = 256,
+ trans_mean = STATISTICS['trans_mean'],
+ trans_std = STATISTICS['trans_std'],
+ flength_mean = STATISTICS['flength_mean'],
+ flength_std = STATISTICS['flength_std'],
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
+ keypoint_weights = KEYPOINT_WEIGHTS
+ )
diff --git a/src/configs/barc_cfg_defaults.py b/src/configs/barc_cfg_defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8be3802c91fe44e974ae453ca574f8cdce5fd80
--- /dev/null
+++ b/src/configs/barc_cfg_defaults.py
@@ -0,0 +1,111 @@
+
+from yacs.config import CfgNode as CN
+import argparse
+import yaml
+import os
+
+abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
+
+_C = CN()
+_C.barc_dir = abs_barc_dir
+_C.device = 'cuda'
+
+## path settings
+_C.paths = CN()
+_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/'
+_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/'
+_C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
+
+## parameter settings
+_C.params = CN()
+_C.params.ARCH = 'hg8'
+_C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae'
+_C.params.NF_VERSION = 3
+_C.params.N_JOINTS = 35
+_C.params.N_KEYP = 24 #20
+_C.params.N_SEG = 2
+_C.params.N_PARTSEG = 15
+_C.params.UPSAMPLE_SEG = True
+_C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt)
+_C.params.N_BETAS = 30 # 10
+_C.params.N_BETAS_LIMBS = 7
+_C.params.N_BONES = 24
+_C.params.N_BREEDS = 121 # 120 breeds plus background
+_C.params.IMG_SIZE = 256
+_C.params.SILH_NO_TAIL = False
+_C.params.KP_THRESHOLD = None
+_C.params.ADD_Z_TO_3D_INPUT = False
+_C.params.N_SEGBPS = 64*2
+_C.params.ADD_SEGBPS_TO_3D_INPUT = True
+_C.params.FIX_FLENGTH = False
+_C.params.RENDER_ALL = True
+_C.params.VLIN = 2
+_C.params.STRUCTURE_Z_TO_B = 'lin'
+_C.params.N_Z_FREE = 64
+_C.params.PCK_THRESH = 0.15
+
+## optimization settings
+_C.optim = CN()
+_C.optim.LR = 5e-4
+_C.optim.SCHEDULE = [150, 175, 200]
+_C.optim.GAMMA = 0.1
+_C.optim.MOMENTUM = 0
+_C.optim.WEIGHT_DECAY = 0
+_C.optim.EPOCHS = 220
+_C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
+_C.optim.TRAIN_PARTS = 'all_without_shapedirs'
+
+## dataset settings
+_C.data = CN()
+_C.data.DATASET = 'stanext24'
+_C.data.V12 = True
+_C.data.SHORTEN_VAL_DATASET_TO = None
+_C.data.VAL_OPT = 'val'
+_C.data.VAL_METRICS = 'no_loss'
+
+# ---------------------------------------
+def update_dependent_vars(cfg):
+ cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG
+ if cfg.params.VLIN == 0:
+ cfg.params.NUM_STAGE_COMB = 2
+ cfg.params.NUM_STAGE_HEADS = 1
+ cfg.params.NUM_STAGE_HEADS_POSE = 1
+ cfg.params.TRANS_SEP = False
+ elif cfg.params.VLIN == 1:
+ cfg.params.NUM_STAGE_COMB = 3
+ cfg.params.NUM_STAGE_HEADS = 1
+ cfg.params.NUM_STAGE_HEADS_POSE = 2
+ cfg.params.TRANS_SEP = False
+ elif cfg.params.VLIN == 2:
+ cfg.params.NUM_STAGE_COMB = 3
+ cfg.params.NUM_STAGE_HEADS = 1
+ cfg.params.NUM_STAGE_HEADS_POSE = 2
+ cfg.params.TRANS_SEP = True
+ else:
+ raise NotImplementedError
+ if cfg.params.STRUCTURE_Z_TO_B == '1dconv':
+ cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS
+ else:
+ cfg.params.N_Z = cfg.params.N_Z_FREE
+ return
+
+
+update_dependent_vars(_C)
+global _cfg_global
+_cfg_global = _C.clone()
+
+
+def get_cfg_defaults():
+ # Get a yacs CfgNode object with default values as defined within this file.
+ # Return a clone so that the defaults will not be altered.
+ return _C.clone()
+
+def update_cfg_global_with_yaml(cfg_yaml_file):
+ _cfg_global.merge_from_file(cfg_yaml_file)
+ update_dependent_vars(_cfg_global)
+ return
+
+def get_cfg_global_updated():
+ # return _cfg_global.clone()
+ return _cfg_global
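+
+# Usage sketch (illustrative only; the yaml path below is hypothetical): a train/test script
+# would typically merge its experiment config into the global config once and afterwards read
+# the updated values via get_cfg_global_updated() everywhere else, e.g.
+#   update_cfg_global_with_yaml('configs/my_experiment.yaml')   # hypothetical file
+#   cfg = get_cfg_global_updated()
+#   print(cfg.params.N_KEYP, cfg.optim.BATCH_SIZE)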
+
diff --git a/src/configs/barc_loss_weights.json b/src/configs/barc_loss_weights.json
new file mode 100644
index 0000000000000000000000000000000000000000..8ddc9e1e6c882431f23b6881c124bf424ae7c3e9
--- /dev/null
+++ b/src/configs/barc_loss_weights.json
@@ -0,0 +1,30 @@
+
+
+
+{
+ "breed_options": [
+ "4"
+ ],
+ "breed": 5.0,
+ "class": 1.0,
+ "models3d": 1.0,
+ "keyp": 0.2,
+ "silh": 50.0,
+ "shape_options": [
+ "smal",
+ "limbs7"
+ ],
+ "shape": [
+ 1e-05,
+ 1
+ ],
+ "poseprior_options": [
+ "normalizing_flow_tiger_logprob"
+ ],
+ "poseprior": 0.1,
+ "poselegssidemovement": 10.0,
+ "flength": 1.0,
+ "partseg": 0,
+ "shapedirs": 0,
+ "pose_0": 0.0
+}
\ No newline at end of file
diff --git a/src/configs/data_info.py b/src/configs/data_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf28608e6361b089d49520e6bf03d142e1aab799
--- /dev/null
+++ b/src/configs/data_info.py
@@ -0,0 +1,115 @@
+from dataclasses import dataclass
+from typing import List
+import json
+import numpy as np
+import os
+import sys
+
+STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
+STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
+
+@dataclass
+class DataInfo:
+ rgb_mean: List[float]
+ rgb_stddev: List[float]
+ joint_names: List[str]
+ hflip_indices: List[int]
+ n_joints: int
+ n_keyp: int
+ n_bones: int
+ n_betas: int
+ image_size: int
+ trans_mean: np.ndarray
+ trans_std: np.ndarray
+ flength_mean: np.ndarray
+ flength_std: np.ndarray
+ pose_rot6d_mean: np.ndarray
+ keypoint_weights: List[float]
+
+# SMAL samples 3d statistics
+# statistics such as the mean values were computed once at the start of the project and have not been changed since
+def load_statistics(statistics_path):
+ with open(statistics_path) as f:
+ statistics = json.load(f)
+ '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
+ statistics['pose_mean'] = new_pose_mean
+ j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
+ with open(self.statistics_path, 'w') as file: file.write(j_out)'''
+ new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
+ 'trans_std': np.asarray(statistics['trans_std']),
+ 'flength_mean': np.asarray(statistics['flength_mean']),
+ 'flength_std': np.asarray(statistics['flength_std']),
+ 'pose_mean': np.asarray(statistics['pose_mean']),
+ }
+ new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
+ return new_statistics
+STATISTICS = load_statistics(STATISTICS_PATH)
+
+
+############################################################################
+# for StanExt (original number of keypoints, 20 not 24)
+
+# for keypoint names see: https://github.com/benjiebob/StanfordExtra/blob/master/keypoint_definitions.csv
+StanExt_JOINT_NAMES = [
+ 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
+ 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
+ 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
+ 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
+ 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
+ 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip']
+
+KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
+
+COMPLETE_DATA_INFO = DataInfo(
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
+ joint_names=StanExt_JOINT_NAMES,
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18],
+ n_joints = 35,
+ n_keyp = 20, # 25,
+ n_bones = 24,
+ n_betas = 30, # 10,
+ image_size = 256,
+ trans_mean = STATISTICS['trans_mean'],
+ trans_std = STATISTICS['trans_std'],
+ flength_mean = STATISTICS['flength_mean'],
+ flength_std = STATISTICS['flength_std'],
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
+ keypoint_weights = KEYPOINT_WEIGHTS
+ )
+
+
+############################################################################
+# new for StanExt24
+
+# ..., 'Left_eye', 'Right_eye', 'Withers', 'Throat'] # the last 4 keypoints are in the animal_pose dataset, but not StanfordExtra
+StanExt_JOINT_NAMES_24 = [
+ 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
+ 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
+ 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
+ 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
+ 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
+ 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip',
+ 'Left_eye', 'Right_eye', 'Withers', 'Throat']
+
+KEYPOINT_WEIGHTS_24 = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2, 1, 1, 0, 0]
+
+COMPLETE_DATA_INFO_24 = DataInfo(
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
+ joint_names=StanExt_JOINT_NAMES_24,
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
+ n_joints = 35,
+ n_keyp = 24, # 20, # 25,
+ n_bones = 24,
+ n_betas = 30, # 10,
+ image_size = 256,
+ trans_mean = STATISTICS['trans_mean'],
+ trans_std = STATISTICS['trans_std'],
+ flength_mean = STATISTICS['flength_mean'],
+ flength_std = STATISTICS['flength_std'],
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
+ keypoint_weights = KEYPOINT_WEIGHTS_24
+ )
+
+
diff --git a/src/configs/dataset_path_configs.py b/src/configs/dataset_path_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c46f58a298dba5f037d0f039c91853c30ade64
--- /dev/null
+++ b/src/configs/dataset_path_configs.py
@@ -0,0 +1,21 @@
+
+
+import numpy as np
+import os
+import sys
+
+abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
+
+# stanext dataset
+# (1) path to stanext dataset
+STAN_V12_ROOT_DIR = abs_barc_dir + '/datasets/StanfordExtra_V12/'
+IMG_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'StanExtV12_Images')
+JSON_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', "StanfordExtra_v12.json")
+STAN_V12_TRAIN_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'train_stanford_StanfordExtra_v12.npy')
+STAN_V12_VAL_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'val_stanford_StanfordExtra_v12.npy')
+STAN_V12_TEST_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'test_stanford_StanfordExtra_v12.npy')
+# (2) path to related data such as breed indices and prepared predictions for withers, throat and eye keypoints
+STANEXT_RELATED_DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'stanext_related_data')
+
+# image crop dataset (for demo, visualization)
+TEST_IMAGE_CROP_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets', 'test_image_crops')
diff --git a/src/configs/dog_breeds/dog_breed_class.py b/src/configs/dog_breeds/dog_breed_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..282052164ec6ecb742d91d07ea564cc82cf70ab8
--- /dev/null
+++ b/src/configs/dog_breeds/dog_breed_class.py
@@ -0,0 +1,170 @@
+
+import os
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+import pandas as pd
+import difflib
+import json
+import pickle as pkl
+import csv
+import numpy as np
+
+
+# ----------------------------------------------------------------------------------------------------------------- #
+class DogBreed(object):
+ def __init__(self, abbrev, name_akc=None, name_stanext=None, name_xlsx=None, path_akc=None, path_stanext=None, ind_in_xlsx=None, ind_in_xlsx_matrix=None, ind_in_stanext=None, clade=None):
+ self._abbrev = abbrev
+ self._name_xlsx = name_xlsx
+ self._name_akc = name_akc
+ self._name_stanext = name_stanext
+ self._path_stanext = path_stanext
+ self._additional_names = set()
+ if self._name_akc is not None:
+ self.add_akc_info(name_akc, path_akc)
+ if self._name_stanext is not None:
+ self.add_stanext_info(name_stanext, path_stanext, ind_in_stanext)
+ if self._name_xlsx is not None:
+ self.add_xlsx_info(name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade)
+ def add_xlsx_info(self, name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade):
+ assert (name_xlsx is not None) and (ind_in_xlsx is not None) and (ind_in_xlsx_matrix is not None) and (clade is not None)
+ self._name_xlsx = name_xlsx
+ self._ind_in_xlsx = ind_in_xlsx
+ self._ind_in_xlsx_matrix = ind_in_xlsx_matrix
+ self._clade = clade
+ def add_stanext_info(self, name_stanext, path_stanext, ind_in_stanext):
+ assert (name_stanext is not None) and (path_stanext is not None) and (ind_in_stanext is not None)
+ self._name_stanext = name_stanext
+ self._path_stanext = path_stanext
+ self._ind_in_stanext = ind_in_stanext
+ def add_akc_info(self, name_akc, path_akc):
+ assert (name_akc is not None) and (path_akc is not None)
+ self._name_akc = name_akc
+ self._path_akc = path_akc
+ def add_additional_names(self, name_list):
+ self._additional_names = self._additional_names.union(set(name_list))
+ def add_text_info(self, text_height, text_weight, text_life_exp):
+ self._text_height = text_height
+ self._text_weight = text_weight
+ self._text_life_exp = text_life_exp
+ def get_datasets(self):
+ # all datasets in which this breed is found
+ datasets = set()
+ if self._name_akc is not None:
+ datasets.add('akc')
+ if self._name_stanext is not None:
+ datasets.add('stanext')
+ if self._name_xlsx is not None:
+ datasets.add('xlsx')
+ return datasets
+ def get_names(self):
+ # set of names for this breed
+ names = {self._abbrev, self._name_akc, self._name_stanext, self._name_xlsx, self._path_stanext}.union(self._additional_names)
+ names.discard(None)
+ return names
+ def get_names_as_pointing_dict(self):
+ # each name points to the abbreviation
+ names = self.get_names()
+ my_dict = {}
+ for name in names:
+ my_dict[name] = self._abbrev
+ return my_dict
+ def print_overview(self):
+ # print important information to get an overview of the class instance
+ if self._name_akc is not None:
+ name = self._name_akc
+ elif self._name_xlsx is not None:
+ name = self._name_xlsx
+ else:
+ name = self._name_stanext
+ print('----------------------------------------------------')
+ print('----- dog breed: ' + name )
+ print('----------------------------------------------------')
+ print('[names]')
+ print(self.get_names())
+ print('[datasets]')
+ print(self.get_datasets())
+ # see https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class
+ print('[instance attributes]')
+ for attribute, value in self.__dict__.items():
+ print(attribute, '=', value)
+ def use_dict_to_save_class_instance(self):
+ my_dict = {}
+ for attribute, value in self.__dict__.items():
+ my_dict[attribute] = value
+ return my_dict
+ def use_dict_to_load_class_instance(self, my_dict):
+ for attribute, value in my_dict.items():
+ setattr(self, attribute, value)
+ return
+
+# ----------------------------------------------------------------------------------------------------------------- #
+def get_name_list_from_summary(summary):
+ name_from_abbrev_dict = {}
+ for breed in summary.values():
+ abbrev = breed._abbrev
+ all_names = breed.get_names()
+ name_from_abbrev_dict[abbrev] = list(all_names)
+ return name_from_abbrev_dict
+def get_partial_summary(summary, part):
+ assert part in ['xlsx', 'akc', 'stanext']
+ partial_summary = {}
+ for key, value in summary.items():
+ if (part == 'xlsx' and value._name_xlsx is not None) \
+ or (part == 'akc' and value._name_akc is not None) \
+ or (part == 'stanext' and value._name_stanext is not None):
+ partial_summary[key] = value
+ return partial_summary
+def get_akc_but_not_stanext_partial_summary(summary):
+ partial_summary = {}
+ for key, value in summary.items():
+ if value._name_akc is not None:
+ if value._name_stanext is None:
+ partial_summary[key] = value
+ return partial_summary
+
+# ----------------------------------------------------------------------------------------------------------------- #
+def main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1):
+ with open(path_complete_abbrev_dict_v1, 'rb') as file:
+ complete_abbrev_dict = pkl.load(file)
+ with open(path_complete_summary_breeds_v1, 'rb') as file:
+ complete_summary_breeds_attributes_only = pkl.load(file)
+
+ complete_summary_breeds = {}
+ for key, value in complete_summary_breeds_attributes_only.items():
+ attributes_only = complete_summary_breeds_attributes_only[key]
+ complete_summary_breeds[key] = DogBreed(abbrev=attributes_only['_abbrev'])
+ complete_summary_breeds[key].use_dict_to_load_class_instance(attributes_only)
+ return complete_abbrev_dict, complete_summary_breeds
+
+
+# ----------------------------------------------------------------------------------------------------------------- #
+def load_similarity_matrix_raw(xlsx_path):
+ # --- LOAD EXCEL FILE FROM DOG BREED PAPER
+ xlsx = pd.read_excel(xlsx_path)
+ # create an array
+ abbrev_indices = {}
+ matrix_raw = np.zeros((168, 168))
+ for ind in range(1, 169):
+ abbrev = xlsx[xlsx.columns[2]][ind]
+ abbrev_indices[abbrev] = ind-1
+ for ind_col in range(0, 168):
+ for ind_row in range(0, 168):
+ matrix_raw[ind_col, ind_row] = float(xlsx[xlsx.columns[3+ind_col]][1+ind_row])
+ return matrix_raw, abbrev_indices
+
+
+
+# ----------------------------------------------------------------------------------------------------------------- #
+# ----------------------------------------------------------------------------------------------------------------- #
+# load the (in advance created) final dict of dog breed classes
+ROOT_PATH_BREED_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'data', 'breed_data')
+path_complete_abbrev_dict_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_abbrev_dict_v2.pkl')
+path_complete_summary_breeds_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_summary_breeds_v2.pkl')
+COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS = main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1)
+# load similarity matrix, data from:
+# Parker H. G., Dreger D. L., Rimbault M., Davis B. W., Mullen A. B., Carpintero-Ramirez G., and Ostrander E. A.
+# Genomic analyses reveal the influence of geographic origin, migration, and hybridization on modern dog breed
+# development. Cell Reports, 19(4):697–708, 2017.
+xlsx_path = os.path.join(ROOT_PATH_BREED_DATA, 'NIHMS866262-supplement-2.xlsx')
+SIM_MATRIX_RAW, SIM_ABBREV_INDICES = load_similarity_matrix_raw(xlsx_path)
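+
+# Usage sketch (illustrative only): COMPLETE_SUMMARY_BREEDS maps keys (breed abbreviations)
+# to DogBreed instances, and SIM_ABBREV_INDICES maps the xlsx abbreviations to row/column
+# indices of SIM_MATRIX_RAW; assuming the similarity matrix is symmetric, the genetic
+# similarity of two breeds that occur in the xlsx can be looked up as follows.
+def _breed_similarity(abbrev_a, abbrev_b):
+    # both arguments are assumed to be abbreviations that appear in SIM_ABBREV_INDICES
+    return SIM_MATRIX_RAW[SIM_ABBREV_INDICES[abbrev_a], SIM_ABBREV_INDICES[abbrev_b]]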
+
diff --git a/src/lifting_to_3d/inn_model_for_shape.py b/src/lifting_to_3d/inn_model_for_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab7c1f18ca603a20406092bdd7163e370d17023
--- /dev/null
+++ b/src/lifting_to_3d/inn_model_for_shape.py
@@ -0,0 +1,61 @@
+
+
+from torch import distributions
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.distributions import Normal
+import numpy as np
+import cv2
+import trimesh
+from tqdm import tqdm
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+import FrEIA.framework as Ff
+import FrEIA.modules as Fm
+
+
+class INNForShape(nn.Module):
+ def __init__(self, n_betas, n_betas_limbs, k_tot=2, betas_scale=1.0, betas_limbs_scale=0.1):
+ super(INNForShape, self).__init__()
+ self.n_betas = n_betas
+ self.n_betas_limbs = n_betas_limbs
+ self.n_dim = n_betas + n_betas_limbs
+ self.betas_scale = betas_scale
+ self.betas_limbs_scale = betas_limbs_scale
+ self.k_tot = 2
+ self.model_inn = self.build_inn_network(self.n_dim, k_tot=self.k_tot)
+
+ def subnet_fc(self, c_in, c_out):
+ subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(),
+ nn.Linear(64, 64), nn.ReLU(),
+ nn.Linear(64, c_out))
+ return subnet
+
+ def build_inn_network(self, n_input, k_tot=12, verbose=False):
+ coupling_block = Fm.RNVPCouplingBlock
+ nodes = [Ff.InputNode(n_input, name='input')]
+ for k in range(k_tot):
+ nodes.append(Ff.Node(nodes[-1],
+ coupling_block,
+ {'subnet_constructor':self.subnet_fc, 'clamp':2.0},
+ name=F'coupling_{k}'))
+ nodes.append(Ff.Node(nodes[-1],
+ Fm.PermuteRandom,
+ {'seed':k},
+ name=F'permute_{k}'))
+ nodes.append(Ff.OutputNode(nodes[-1], name='output'))
+ model = Ff.ReversibleGraphNet(nodes, verbose=verbose)
+ return model
+
+ def forward(self, latent_rep):
+ shape, _ = self.model_inn(latent_rep, rev=False, jac=False)
+ betas = shape[:, :self.n_betas]*self.betas_scale
+ betas_limbs = shape[:, self.n_betas:]*self.betas_limbs_scale
+ return betas, betas_limbs
+
+ def reverse(self, betas, betas_limbs):
+ shape = torch.cat((betas/self.betas_scale, betas_limbs/self.betas_limbs_scale), dim=1)
+ latent_rep, _ = self.model_inn(shape, rev=True, jac=False)
+ return latent_rep
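+
+
+if __name__ == "__main__":
+    # Minimal round-trip sketch (illustrative only, random inputs): the INN maps a latent
+    # vector to (betas, betas_limbs) and reverse() inverts this up to numerical precision.
+    inn = INNForShape(n_betas=30, n_betas_limbs=7)
+    latent = torch.randn(4, 30 + 7)
+    betas, betas_limbs = inn(latent)
+    latent_rec = inn.reverse(betas, betas_limbs)
+    print(betas.shape, betas_limbs.shape, float((latent - latent_rec).abs().max()))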
\ No newline at end of file
diff --git a/src/lifting_to_3d/linear_model.py b/src/lifting_to_3d/linear_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c11266acefcb6bbecd8a748a44cb4915ef4da4b9
--- /dev/null
+++ b/src/lifting_to_3d/linear_model.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# some code from https://raw.githubusercontent.com/weigq/3d_pose_baseline_pytorch/master/src/model.py
+
+
+from __future__ import absolute_import
+from __future__ import print_function
+import torch
+import torch.nn as nn
+
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+# from priors.vae_pose_model.vae_model import VAEmodel
+from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
+
+
+def weight_init_dangerous(m):
+ # this is dangerous as it may overwrite the normalizing flow weights
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_normal(m.weight)
+
+
+class Linear(nn.Module):
+ def __init__(self, linear_size, p_dropout=0.5):
+ super(Linear, self).__init__()
+ self.l_size = linear_size
+
+ self.relu = nn.ReLU(inplace=True)
+ self.dropout = nn.Dropout(p_dropout)
+
+ self.w1 = nn.Linear(self.l_size, self.l_size)
+ self.batch_norm1 = nn.BatchNorm1d(self.l_size)
+
+ self.w2 = nn.Linear(self.l_size, self.l_size)
+ self.batch_norm2 = nn.BatchNorm1d(self.l_size)
+
+ def forward(self, x):
+ y = self.w1(x)
+ y = self.batch_norm1(y)
+ y = self.relu(y)
+ y = self.dropout(y)
+ y = self.w2(y)
+ y = self.batch_norm2(y)
+ y = self.relu(y)
+ y = self.dropout(y)
+ out = x + y
+ return out
+
+
+class LinearModel(nn.Module):
+ def __init__(self,
+ linear_size=1024,
+ num_stage=2,
+ p_dropout=0.5,
+ input_size=16*2,
+ output_size=16*3):
+ super(LinearModel, self).__init__()
+ self.linear_size = linear_size
+ self.p_dropout = p_dropout
+ self.num_stage = num_stage
+ # input
+ self.input_size = input_size # 2d joints: 16 * 2
+ # output
+ self.output_size = output_size # 3d joints: 16 * 3
+ # process input to linear size
+ self.w1 = nn.Linear(self.input_size, self.linear_size)
+ self.batch_norm1 = nn.BatchNorm1d(self.linear_size)
+ self.linear_stages = []
+ for l in range(num_stage):
+ self.linear_stages.append(Linear(self.linear_size, self.p_dropout))
+ self.linear_stages = nn.ModuleList(self.linear_stages)
+ # post-processing
+ self.w2 = nn.Linear(self.linear_size, self.output_size)
+ # helpers (relu and dropout)
+ self.relu = nn.ReLU(inplace=True)
+ self.dropout = nn.Dropout(self.p_dropout)
+
+ def forward(self, x):
+ # pre-processing
+ y = self.w1(x)
+ y = self.batch_norm1(y)
+ y = self.relu(y)
+ y = self.dropout(y)
+ # linear layers
+ for i in range(self.num_stage):
+ y = self.linear_stages[i](y)
+ # post-processing
+ y = self.w2(y)
+ return y
+
+
+class LinearModelComplete(nn.Module):
+ def __init__(self,
+ linear_size=1024,
+ num_stage_comb=2,
+ num_stage_heads=1,
+ num_stage_heads_pose=1,
+ trans_sep=False,
+ p_dropout=0.5,
+ input_size=16*2,
+ intermediate_size=1024,
+ output_info=None,
+ n_joints=25,
+ n_z=512,
+ add_z_to_3d_input=False,
+ n_segbps=64*2,
+ add_segbps_to_3d_input=False,
+ structure_pose_net='default',
+ fix_vae_weights=True,
+ nf_version=None): # 0): n_silh_enc
+ super(LinearModelComplete, self).__init__()
+ if add_z_to_3d_input:
+ self.n_z_to_add = n_z # 512
+ else:
+ self.n_z_to_add = 0
+ if add_segbps_to_3d_input:
+ self.n_segbps_to_add = n_segbps # 64
+ else:
+ self.n_segbps_to_add = 0
+ self.input_size = input_size
+ self.linear_size = linear_size
+ self.p_dropout = p_dropout
+ self.num_stage_comb = num_stage_comb
+ self.num_stage_heads = num_stage_heads
+ self.num_stage_heads_pose = num_stage_heads_pose
+ self.trans_sep = trans_sep
+ self.input_size = input_size
+ self.intermediate_size = intermediate_size
+ self.structure_pose_net = structure_pose_net
+ self.fix_vae_weights = fix_vae_weights # only relevant if structure_pose_net='vae'
+ self.nf_version = nf_version
+ if output_info is None:
+ pose = {'name': 'pose', 'n': n_joints*6, 'out_shape':[n_joints, 6]}
+ cam = {'name': 'flength', 'n': 1}
+ if self.trans_sep:
+ translation_xy = {'name': 'trans_xy', 'n': 2}
+ translation_z = {'name': 'trans_z', 'n': 1}
+ self.output_info = [pose, translation_xy, translation_z, cam]
+ else:
+ translation = {'name': 'trans', 'n': 3}
+ self.output_info = [pose, translation, cam]
+ if self.structure_pose_net == 'vae' or self.structure_pose_net == 'normflow':
+ global_pose = {'name': 'global_pose', 'n': 1*6, 'out_shape':[1, 6]}
+ self.output_info.append(global_pose)
+ else:
+ self.output_info = output_info
+ self.linear_combined = LinearModel(linear_size=self.linear_size,
+ num_stage=self.num_stage_comb,
+ p_dropout=p_dropout,
+ input_size=self.input_size + self.n_segbps_to_add + self.n_z_to_add, ######
+ output_size=self.intermediate_size)
+ self.output_info_linear_models = []
+ for ind_el, element in enumerate(self.output_info):
+ if element['name'] == 'pose':
+ num_stage = self.num_stage_heads_pose
+ if self.structure_pose_net == 'default':
+ output_size_pose_lin = element['n']
+ elif self.structure_pose_net == 'vae':
+ # load vae decoder
+ self.pose_vae_model = VAEmodel()
+ self.pose_vae_model.initialize_with_pretrained_weights()
+ # define the input size of the vae decoder
+ output_size_pose_lin = self.pose_vae_model.latent_size
+ elif self.structure_pose_net == 'normflow':
+ # the following will automatically be initialized
+ self.pose_normflow_model = NormalizingFlowPrior(nf_version=self.nf_version)
+ output_size_pose_lin = element['n'] - 6 # no global rotation
+ else:
+ raise NotImplementedError
+ self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
+ num_stage=num_stage,
+ p_dropout=p_dropout,
+ input_size=self.intermediate_size,
+ output_size=output_size_pose_lin))
+ else:
+ if element['name'] == 'global_pose':
+ num_stage = self.num_stage_heads_pose
+ else:
+ num_stage = self.num_stage_heads
+ self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
+ num_stage=num_stage,
+ p_dropout=p_dropout,
+ input_size=self.intermediate_size,
+ output_size=element['n']))
+ element['linear_model_index'] = ind_el
+ self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models)
+
+ def forward(self, x):
+ device = x.device
+ # combined stage
+ if x.shape[1] == self.input_size + self.n_segbps_to_add + self.n_z_to_add:
+ y = self.linear_combined(x)
+ elif x.shape[1] == self.input_size + self.n_segbps_to_add:
+ x_mod = torch.cat((x, torch.normal(0, 1, size=(x.shape[0], self.n_z_to_add)).to(device)), dim=1)
+ y = self.linear_combined(x_mod)
+ else:
+ print(x.shape)
+ print(self.input_size)
+ print(self.n_segbps_to_add)
+ print(self.n_z_to_add)
+ raise ValueError
+ # heads
+ results = {}
+ results_trans = {}
+ for element in self.output_info:
+ linear_model = self.output_info_linear_models[element['linear_model_index']]
+ if element['name'] == 'pose':
+ if self.structure_pose_net == 'default':
+ results['pose'] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
+ normflow_z = None
+ elif self.structure_pose_net == 'vae':
+ res_lin = linear_model(y)
+ if self.fix_vae_weights:
+ self.pose_vae_model.requires_grad_(False) # let gradients flow through but don't update the parameters
+ res_vae = self.pose_vae_model.inference(feat=res_lin)
+ self.pose_vae_model.requires_grad_(True)
+ else:
+ res_vae = self.pose_vae_model.inference(feat=res_lin)
+ res_pose_not_glob = res_vae.reshape((-1, element['out_shape'][0], element['out_shape'][1]))
+ normflow_z = None
+ elif self.structure_pose_net == 'normflow':
+ normflow_z = linear_model(y)*0.1
+ self.pose_normflow_model.requires_grad_(False) # let gradients flow through but don't update the parameters
+ res_pose_not_glob = self.pose_normflow_model.run_backwards(z=normflow_z).reshape((-1, element['out_shape'][0]-1, element['out_shape'][1]))
+ else:
+ raise NotImplementedError
+ elif element['name'] == 'global_pose':
+ res_pose_glob = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
+ elif element['name'] == 'trans_xy' or element['name'] == 'trans_z':
+ results_trans[element['name']] = linear_model(y)
+ else:
+ results[element['name']] = linear_model(y)
+ if self.trans_sep:
+ results['trans'] = torch.cat((results_trans['trans_xy'], results_trans['trans_z']), dim=1)
+ # prepare pose including global rotation
+ if self.structure_pose_net == 'vae':
+ # results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob), dim=1)
+ results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, 1:, :]), dim=1)
+ elif self.structure_pose_net == 'normflow':
+ results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, :, :]), dim=1)
+ # return a dictionary which contains all results
+ results['normflow_z'] = normflow_z
+ return results # this is a dictionary
+
+
+
+
+
+# ------------------------------------------
+# for pretraining of the 3d model only:
+# (see combined_model/model_shape_v2.py)
+
+class Wrapper_LinearModelComplete(nn.Module):
+ def __init__(self,
+ linear_size=1024,
+ num_stage_comb=2,
+ num_stage_heads=1,
+ num_stage_heads_pose=1,
+ trans_sep=False,
+ p_dropout=0.5,
+ input_size=16*2,
+ intermediate_size=1024,
+ output_info=None,
+ n_joints=25,
+ n_z=512,
+ add_z_to_3d_input=False,
+ n_segbps=64*2,
+ add_segbps_to_3d_input=False,
+ structure_pose_net='default',
+ fix_vae_weights=True,
+ nf_version=None):
+ self.add_segbps_to_3d_input = add_segbps_to_3d_input
+ super(Wrapper_LinearModelComplete, self).__init__()
+ self.model_3d = LinearModelComplete(linear_size=linear_size,
+ num_stage_comb=num_stage_comb,
+ num_stage_heads=num_stage_heads,
+ num_stage_heads_pose=num_stage_heads_pose,
+ trans_sep=trans_sep,
+ p_dropout=p_dropout, # 0.5,
+ input_size=input_size,
+ intermediate_size=intermediate_size,
+ output_info=output_info,
+ n_joints=n_joints,
+ n_z=n_z,
+ add_z_to_3d_input=add_z_to_3d_input,
+ n_segbps=n_segbps,
+ add_segbps_to_3d_input=add_segbps_to_3d_input,
+ structure_pose_net=structure_pose_net,
+ fix_vae_weights=fix_vae_weights,
+ nf_version=nf_version)
+ def forward(self, input_vec):
+ # input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
+ # predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
+ output = self.model_3d(input_vec)
+ return output
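+
+
+if __name__ == "__main__":
+    # Shape sketch (illustrative only; the sizes follow the BARC defaults but are assumptions
+    # here): with structure_pose_net='default' no pretrained pose prior is required, the input
+    # is a flattened 2d-keypoint / bone-length vector and the output is a dict of 3d parameters.
+    n_keyp, n_bones = 24, 24
+    model = LinearModelComplete(input_size=n_keyp * 3 + n_bones,
+                                n_joints=35,
+                                add_z_to_3d_input=False,
+                                add_segbps_to_3d_input=False,
+                                structure_pose_net='default')
+    model.eval()
+    out = model(torch.randn(4, n_keyp * 3 + n_bones))
+    print(out['pose'].shape, out['trans'].shape, out['flength'].shape)  # (4, 35, 6) (4, 3) (4, 1)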
\ No newline at end of file
diff --git a/src/lifting_to_3d/utils/geometry_utils.py b/src/lifting_to_3d/utils/geometry_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a83466b7212cab58d6f4c0f88ae98a206583a3f8
--- /dev/null
+++ b/src/lifting_to_3d/utils/geometry_utils.py
@@ -0,0 +1,236 @@
+
+import torch
+from torch.nn import functional as F
+import numpy as np
+from torch import nn
+
+
+def geodesic_loss(R, Rgt):
+ # see: Silvia tiger pose model 3d code
+ num_joints = R.shape[1]
+ RT = R.permute(0,1,3,2)
+ A = torch.matmul(RT.view(-1,3,3),Rgt.view(-1,3,3))
+ # torch.trace works only for 2D tensors
+ n = A.shape[0]
+ po_loss = 0
+ eps = 1e-7
+ T = torch.sum(A[:,torch.eye(3).bool()],1)
+ theta = torch.clamp(0.5*(T-1), -1+eps, 1-eps)
+ angles = torch.acos(theta)
+ loss = torch.sum(angles)/(n*num_joints)
+ return loss
+
+class geodesic_loss_R(nn.Module):
+ def __init__(self,reduction='mean'):
+ super(geodesic_loss_R, self).__init__()
+ self.reduction = reduction
+ self.eps = 1e-6
+
+ # batch geodesic loss for rotation matrices
+ def bgdR(self,bRgts,bRps):
+ #return((bRgts - bRps)**2.).mean()
+ return geodesic_loss(bRgts, bRps)
+
+ def forward(self, ypred, ytrue):
+ theta = geodesic_loss(ypred,ytrue)
+ if self.reduction == 'mean':
+ return torch.mean(theta)
+ else:
+ return theta
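+
+# Quick sanity-check sketch (illustrative only): identical rotation matrices give a loss of
+# (numerically) zero, e.g.
+#   R = torch.eye(3).repeat(2, 4, 1, 1)      # batch of 2 samples, 4 joints
+#   geodesic_loss(R, R)                      # ~0 (the clamp keeps acos away from exactly 1)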
+
+def batch_rodrigues_numpy(theta):
+ """ Code adapted from spin
+ Convert axis-angle representation to rotation matrix.
+ Remark:
+ this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta)
+ Args:
+ theta: size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ l1norm = np.linalg.norm(theta + 1e-8, ord = 2, axis = 1)
+ # angle = np.unsqueeze(l1norm, -1)
+ angle = l1norm.reshape((-1, 1))
+ # normalized = np.div(theta, angle)
+ normalized = theta / angle
+ angle = angle * 0.5
+ v_cos = np.cos(angle)
+ v_sin = np.sin(angle)
+ # quat = np.cat([v_cos, v_sin * normalized], dim = 1)
+ quat = np.concatenate([v_cos, v_sin * normalized], axis = 1)
+ return quat_to_rotmat_numpy(quat)
+
+def quat_to_rotmat_numpy(quat):
+ """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py
+ Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ # norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ norm_quat = norm_quat/np.linalg.norm(norm_quat, ord=2, axis=1, keepdims=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+ B = quat.shape[0]
+ # w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ w2, x2, y2, z2 = w**2, x**2, y**2, z**2
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+ rotMat = np.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], axis=1).reshape(B, 3, 3)
+ return rotMat
+
+
+def batch_rodrigues(theta):
+ """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py
+ Convert axis-angle representation to rotation matrix.
+ Remark:
+ this leads to the same result as kornia.angle_axis_to_rotation_matrix(theta)
+ Args:
+ theta: size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def batch_rot2aa(Rs, epsilon=1e-7):
+ """ Code from: https://github.com/vchoutas/expose/blob/dffc38d62ad3817481d15fe509a93c2bb606cb8b/expose/utils/rotation_utils.py#L55
+ Rs is B x 3 x 3
+ void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis,
+ double& out_theta)
+ {
+ double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1);
+ c = cMathUtil::Clamp(c, -1.0, 1.0);
+ out_theta = std::acos(c);
+ if (std::abs(out_theta) < 0.00001)
+ {
+ out_axis = tVector(0, 0, 1, 0);
+ }
+ else
+ {
+ double m21 = mat(2, 1) - mat(1, 2);
+ double m02 = mat(0, 2) - mat(2, 0);
+ double m10 = mat(1, 0) - mat(0, 1);
+ double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10);
+ out_axis[0] = m21 / denom;
+ out_axis[1] = m02 / denom;
+ out_axis[2] = m10 / denom;
+ out_axis[3] = 0;
+ }
+ }
+ """
+ cos = 0.5 * (torch.einsum('bii->b', [Rs]) - 1)
+ cos = torch.clamp(cos, -1 + epsilon, 1 - epsilon)
+ theta = torch.acos(cos)
+ m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
+ m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
+ m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
+ denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10 + epsilon)
+ axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom)
+ axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom)
+ axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom)
+ return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1)
+
+def quat_to_rotmat(quat):
+ """Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py
+ Convert quaternion coefficients to rotation matrix.
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def rot6d_to_rotmat(rot6d):
+ """ Code from: https://github.com/nkolot/SPIN/blob/master/utils/geometry.py
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,6) Batch of 6-D rotation representations
+ Output:
+ (B,3,3) Batch of corresponding rotation matrices
+ """
+ rot6d = rot6d.view(-1,3,2)
+ a1 = rot6d[:, :, 0]
+ a2 = rot6d[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ rotmat = torch.stack((b1, b2, b3), dim=-1)
+ return rotmat
+
+def rotmat_to_rot6d(rotmat):
+ """ Convert 3x3 rotation matrix to 6D rotation representation.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,3,3) Batch of corresponding rotation matrices
+ Output:
+ (B,6) Batch of 6-D rotation representations
+ """
+ rot6d = rotmat[:, :, :2].reshape((-1, 6))
+ return rot6d
+
+
+def main():
+ # rotation matrix and 6d representation
+ # see "On the Continuity of Rotation Representations in Neural Networks"
+ from pyquaternion import Quaternion
+ batch_size = 5
+ rotmat = np.zeros((batch_size, 3, 3))
+ for ind in range(0, batch_size):
+ rotmat[ind, :, :] = Quaternion.random().rotation_matrix
+ rotmat_torch = torch.Tensor(rotmat)
+ rot6d = rotmat_to_rot6d(rotmat_torch)
+ rotmat_rec = rot6d_to_rotmat(rot6d)
+ print('..................... 1 ....................')
+ print(rotmat_torch[0, :, :])
+ print(rotmat_rec[0, :, :])
+ print('Conversion from rotmat to rot6d and inverse are ok!')
+ # rotation matrix and axis angle representation
+ import kornia
+ input = torch.rand(1, 3)
+ output = kornia.angle_axis_to_rotation_matrix(input)
+ input_rec = kornia.rotation_matrix_to_angle_axis(output)
+ print('..................... 2 ....................')
+ print(input)
+ print(input_rec)
+ print('Kornia implementation for rotation_matrix_to_angle_axis is wrong!!!!')
+ # For conversions that do not need to be differentiable, use scipy:
+ # https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html
+ from scipy.spatial.transform import Rotation as R
+ r = R.from_matrix(rotmat[0, :, :])
+ print('..................... 3 ....................')
+ print(r.as_matrix())
+ print(r.as_rotvec())
+ print(r.as_quat())
+ # one might furthermore have a look at:
+ # https://github.com/silviazuffi/smalst/blob/master/utils/transformations.py
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/src/metrics/metrics.py b/src/metrics/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffa1ae1c00bd286f55a4ede8565dc3eb619162a9
--- /dev/null
+++ b/src/metrics/metrics.py
@@ -0,0 +1,74 @@
+# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py
+
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+IMG_RES = 256 # in WLDO it is 224
+
+class Metrics():
+
+ @staticmethod
+ def PCK_thresh(
+ pred_keypoints, gt_keypoints,
+ gtseg, has_seg,
+ thresh, idxs, biggs=False):
+
+ pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg]
+
+ if idxs is None:
+ idxs = list(range(pred_keypoints.shape[1]))
+
+ idxs = np.array(idxs).astype(int)
+
+ pred_keypoints = pred_keypoints[:, idxs]
+ gt_keypoints = gt_keypoints[:, idxs]
+
+ if biggs:
+ keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES
+ dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim = -1)
+ else:
+ keypoints_gt = gt_keypoints # (0 to IMG_SIZE)
+ dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim = -1)
+
+ seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim = -1).unsqueeze(-1)
+
+ hits = (dist / torch.sqrt(seg_area)) < thresh
+ total_visible = torch.sum(gt_keypoints[:, :, -1], dim = -1)
+ pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim = -1) / total_visible
+
+ return pck
+
+ @staticmethod
+ def PCK(
+ pred_keypoints, keypoints,
+ gtseg, has_seg,
+ thresh_range=[0.15],
+ idxs:list=None,
+ biggs=False):
+ """Calc PCK with same method as in eval.
+ idxs = optional list of subset of keypoints to index from
+ """
+ cumulative_pck = []
+ for thresh in thresh_range:
+ pck = Metrics.PCK_thresh(
+ pred_keypoints, keypoints,
+ gtseg, has_seg, thresh, idxs,
+ biggs=biggs)
+ cumulative_pck.append(pck)
+ pck_mean = torch.stack(cumulative_pck, dim = 0).mean(dim=0)
+ return pck_mean
+
+ @staticmethod
+ def IOU(synth_silhouettes, gt_seg, img_border_mask, mask):
+ for i in range(mask.shape[0]):
+ synth_silhouettes[i] *= mask[i]
+ # Do not penalize parts of the segmentation outside the img range
+ gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask)
+ intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim = -1)
+ union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim = -1)
+ acc_IOU_SCORE = intersection / union
+ if torch.isnan(acc_IOU_SCORE).sum() > 0:
+ import pdb; pdb.set_trace()
+ return acc_IOU_SCORE
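+
+
+if __name__ == "__main__":
+    # Shape sketch (illustrative only, random tensors): PCK expects image-space predicted
+    # keypoints, ground-truth keypoints with a visibility flag in the last channel, and
+    # ground-truth segmentation masks whose area normalizes the distance threshold.
+    bs, n_kp = 2, 20
+    pred_kp = torch.rand(bs, n_kp, 2) * IMG_RES
+    gt_kp = torch.cat((torch.rand(bs, n_kp, 2) * IMG_RES, torch.ones(bs, n_kp, 1)), dim=2)
+    gt_seg = torch.ones(bs, IMG_RES, IMG_RES)
+    has_seg = torch.ones(bs, dtype=torch.bool)
+    print(Metrics.PCK(pred_kp, gt_kp, gt_seg, has_seg, thresh_range=[0.15]))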
\ No newline at end of file
diff --git a/src/priors/normalizing_flow_prior/normalizing_flow_prior.py b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf60fe51d722c31d7b045a637e1b57d4b577091
--- /dev/null
+++ b/src/priors/normalizing_flow_prior/normalizing_flow_prior.py
@@ -0,0 +1,115 @@
+
+from torch import distributions
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.distributions import Normal
+import numpy as np
+import cv2
+import trimesh
+from tqdm import tqdm
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+import FrEIA.framework as Ff
+import FrEIA.modules as Fm
+from configs.barc_cfg_defaults import get_cfg_global_updated
+
+
+class NormalizingFlowPrior(nn.Module):
+ def __init__(self, nf_version=None):
+ super(NormalizingFlowPrior, self).__init__()
+ # the normalizing flow network takes as input a vector of size (35-1)*6, i.e.
+ # [all joints except the root joint] * 6. At the moment each rotation is given in the
+ # 6D rotation representation, which is not ideal for a normalizing flow prior.
+ # Nevertheless, in practice the results seem to be fine.
+ n_dim = (35 - 1) * 6
+ self.param_dict = self.get_version_param_dict(nf_version)
+ self.model_inn = self.build_inn_network(n_dim, k_tot=self.param_dict['k_tot'])
+ self.initialize_with_pretrained_weights()
+
+ def get_version_param_dict(self, nf_version):
+ # we trained several versions of the normalizing flow pose prior; here we only provide
+ # the option that was used for the CVPR 2022 paper (nf_version=3)
+ if nf_version == 3:
+ param_dict = {
+ 'k_tot': 2,
+ 'path_pretrained': get_cfg_global_updated().paths.MODELPATH_NORMFLOW,
+ 'subnet_fc_type': '3_64'}
+ else:
+ print(nf_version)
+ raise ValueError
+ return param_dict
+
+ def initialize_with_pretrained_weights(self, weight_path=None):
+ # The normalizing flow pose prior is pretrained separately. Afterwards all weights
+ # are kept fixed. Here we load those pretrained weights.
+ if weight_path is None:
+ weight_path = self.param_dict['path_pretrained']
+ print(' normalizing flow pose prior: loading {}..'.format(weight_path))
+ pretrained_dict = torch.load(weight_path)['model_state_dict']
+ self.model_inn.load_state_dict(pretrained_dict, strict=True)
+
+ def subnet_fc(self, c_in, c_out):
+ if self.param_dict['subnet_fc_type']=='3_512':
+ subnet = nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(),
+ nn.Linear(512, 512), nn.ReLU(),
+ nn.Linear(512, c_out))
+ elif self.param_dict['subnet_fc_type']=='3_64':
+ subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(),
+ nn.Linear(64, 64), nn.ReLU(),
+ nn.Linear(64, c_out))
+ return subnet
+
+ def build_inn_network(self, n_input, k_tot=12, verbose=False):
+ coupling_block = Fm.RNVPCouplingBlock
+ nodes = [Ff.InputNode(n_input, name='input')]
+ for k in range(k_tot):
+ nodes.append(Ff.Node(nodes[-1],
+ coupling_block,
+ {'subnet_constructor':self.subnet_fc, 'clamp':2.0},
+ name=F'coupling_{k}'))
+ nodes.append(Ff.Node(nodes[-1],
+ Fm.PermuteRandom,
+ {'seed':k},
+ name=F'permute_{k}'))
+ nodes.append(Ff.OutputNode(nodes[-1], name='output'))
+ model = Ff.ReversibleGraphNet(nodes, verbose=verbose)
+ return model
+
+ def calculate_loss_from_z(self, z, type='square'):
+ assert type in ['square', 'neg_log_prob']
+ if type == 'square':
+ loss = (z**2).mean() # * 0.00001
+ elif type == 'neg_log_prob':
+ means = torch.zeros((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
+ stds = torch.ones((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
+ normal_distribution = Normal(means, stds)
+ log_prob = normal_distribution.log_prob(z)
+ loss = - log_prob.mean()
+ return loss
+
+ def calculate_loss(self, poses_rot6d, type='square'):
+ assert type in ['square', 'neg_log_prob']
+ poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))
+ z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
+ loss = self.calculate_loss_from_z(z, type=type)
+ return loss
+
+ def forward(self, poses_rot6d):
+ # from pose to latent pose representation z
+ # poses_rot6d has shape (bs, 34, 6)
+ poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))
+ z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
+ return z
+
+ def run_backwards(self, z):
+ # from latent pose representation z to pose
+ poses_rot6d_noglob, _ = self.model_inn(z, rev=True, jac=False)
+ return poses_rot6d_noglob
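+
+
+# Minimal usage sketch (added for illustration, not used elsewhere in the code base):
+# the prior is kept frozen and queried with predicted 6D joint rotations to obtain a
+# pose-prior loss term. Running it requires the pretrained normalizing flow weights
+# referenced in the config (paths.MODELPATH_NORMFLOW).
+if __name__ == "__main__":
+    prior = NormalizingFlowPrior(nf_version=3)
+    for p in prior.parameters():
+        p.requires_grad = False
+    poses_rot6d = torch.randn(2, 35, 6)  # (bs, n_joints=35, 6), root joint included
+    print('square loss:      ', prior.calculate_loss(poses_rot6d, type='square').item())
+    print('neg log prob loss:', prior.calculate_loss(poses_rot6d, type='neg_log_prob').item())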
+
+
+
+
+
\ No newline at end of file
diff --git a/src/priors/shape_prior.py b/src/priors/shape_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..f62ebc5d656aa6829427746d9582700db38481cc
--- /dev/null
+++ b/src/priors/shape_prior.py
@@ -0,0 +1,40 @@
+
+# some parts of the code adapted from https://github.com/benjiebob/WLDO and https://github.com/benjiebob/SMALify
+
+import numpy as np
+import torch
+import pickle as pkl
+
+
+
+class ShapePrior(torch.nn.Module):
+ def __init__(self, prior_path):
+ super(ShapePrior, self).__init__()
+ try:
+ with open(prior_path, 'r') as f:
+ res = pkl.load(f)
+ except (UnicodeDecodeError, TypeError) as e:
+ with open(prior_path, 'rb') as file:
+ u = pkl._Unpickler(file)
+ u.encoding = 'latin1'
+ res = u.load()
+ betas_mean = res['dog_cluster_mean']
+ betas_cov = res['dog_cluster_cov']
+ single_gaussian_inv_covs = np.linalg.inv(betas_cov + 1e-5 * np.eye(betas_cov.shape[0]))
+ single_gaussian_precs = torch.tensor(np.linalg.cholesky(single_gaussian_inv_covs)).float()
+ single_gaussian_means = torch.tensor(betas_mean).float()
+ self.register_buffer('single_gaussian_precs', single_gaussian_precs) # (20, 20)
+ self.register_buffer('single_gaussian_means', single_gaussian_means) # (20)
+ use_ind_tch = torch.from_numpy(np.ones(single_gaussian_means.shape[0], dtype=bool)).float() # .to(device)
+ self.register_buffer('use_ind_tch', use_ind_tch)
+
+ def forward(self, betas_smal_orig, use_singe_gaussian=False):
+ n_betas_smal = betas_smal_orig.shape[1]
+ device = betas_smal_orig.device
+ use_ind_tch_corrected = self.use_ind_tch * torch.cat((torch.ones_like(self.use_ind_tch[:n_betas_smal]), torch.zeros_like(self.use_ind_tch[n_betas_smal:])))
+ samples = torch.cat((betas_smal_orig, torch.zeros((betas_smal_orig.shape[0], self.single_gaussian_means.shape[0]-n_betas_smal)).float().to(device)), dim=1)
+ mean_sub = samples - self.single_gaussian_means.unsqueeze(0)
+ single_gaussian_precs_corr = self.single_gaussian_precs * use_ind_tch_corrected[:, None] * use_ind_tch_corrected[None, :]
+ res = torch.tensordot(mean_sub, single_gaussian_precs_corr, dims = ([1], [0]))
+ res_final_mean_2 = torch.mean(res ** 2)
+ return res_final_mean_2
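+
+
+# Usage sketch (illustrative; the path below is a placeholder): the prior pickle is
+# expected to contain 'dog_cluster_mean' and 'dog_cluster_cov', and the forward pass
+# returns a scalar Mahalanobis-style penalty on the predicted SMAL shape coefficients.
+#
+#   shape_prior = ShapePrior('<path_to_shape_prior_pkl>')
+#   betas = torch.zeros((batch_size, 20))
+#   loss_shape = shape_prior(betas)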
diff --git a/src/smal_pytorch/renderer/differentiable_renderer.py b/src/smal_pytorch/renderer/differentiable_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d76f1f34a16ee6b559e18d95ecaca4fa267b31
--- /dev/null
+++ b/src/smal_pytorch/renderer/differentiable_renderer.py
@@ -0,0 +1,280 @@
+
+# part of the code from
+# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py
+
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+import numpy as np
+# import config
+
+import pytorch3d
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+ PerspectiveCameras, look_at_view_transform, look_at_rotation,
+ RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
+ PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures,
+ DirectionalLights
+)
+from pytorch3d.renderer import TexturesVertex, SoftPhongShader
+from pytorch3d.io import load_objs_as_meshes
+
+MESH_COLOR_0 = [0, 172, 223]
+MESH_COLOR_1 = [172, 223, 0]
+
+
+'''
+Explanation of the shift between projection results from opendr and pytorch3d:
+ (0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr)
+ imagine you have an image of size 4:
+ middle of the first pixel is 0
+ middle of the last pixel is 3
+      => the middle of the image would be 1.5 and not 2!
+ so in order to go from pytorch3d predictions to opendr we would calculate: p_odr = p_p3d * (128/127.5)
+To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps:
+ 1.) build camera matrix
+         K = np.array([[flength, 0, c_x],
+                       [0, flength, c_y],
+                       [0, 0, 1]], np.float32)
+ 2.) we don't need to add extrinsics, as the mesh comes with translation (which is
+ added within smal_pytorch). all 3d points are already in the camera coordinate system.
+ -> projection reduces to p2d_proj = K*p3d
+ 3.) convert to pytorch3d conventions (0 in the middle of the first pixel)
+ p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.)
+renderer.py - project_points_p3d: shows an example of what is described above, but
+    with the same focal length for the whole batch
+
+'''
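+
+def project_points_by_hand_sketch(points, flength, image_size):
+    # Added for illustration only (not used by SilhRenderer, which relies on pytorch3d):
+    # a minimal sketch of the manual projection described in the comment block above,
+    # assuming a single shared focal length for the whole batch.
+    # points: (bs, n_points, 3) in camera coordinates, flength: scalar, image_size: int (square images)
+    c = image_size / 2.
+    u = points[:, :, 0] / points[:, :, 2] * flength + c
+    v = points[:, :, 1] / points[:, :, 2] * flength + c
+    p2d_proj = torch.stack((u, v), dim=2)
+    # convert to the pytorch3d convention (0 is the middle of the first pixel)
+    return p2d_proj / image_size * (image_size - 1.)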
+
+class SilhRenderer(torch.nn.Module):
+ def __init__(self, image_size, adapt_R_wldo=False):
+ super(SilhRenderer, self).__init__()
+        # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315
+        # adapt_R_wldo=False corresponds to the setting used for all experiments in this code base (see below)
+        # image_size: a single integer (square images)
+ # -----
+ # set mesh color
+ self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0))
+ self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1))
+ # prepare extrinsics, which in our case don't change
+ R = torch.Tensor(np.eye(3)).float()[None, :, :]
+ T = torch.Tensor(np.zeros((1, 3))).float()
+ if adapt_R_wldo:
+ R[0, 0, 0] = -1
+ else: # used for all my own experiments
+ R[0, 0, 0] = -1
+ R[0, 1, 1] = -1
+ self.register_buffer('R', R)
+ self.register_buffer('T', T)
+ # prepare that part of the intrinsics which does not change either
+ # principal_point_prep = torch.Tensor([self.image_size / 2., self.image_size / 2.]).float()[None, :].float().to(device)
+ # image_size_prep = torch.Tensor([self.image_size, self.image_size]).float()[None, :].float().to(device)
+ self.img_size_scalar = image_size
+ self.register_buffer('image_size', torch.Tensor([image_size, image_size]).float()[None, :].float())
+ self.register_buffer('principal_point', torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :].float())
+ # Rasterization settings for differentiable rendering, where the blur_radius
+ # initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable
+ # Renderer for Image-based 3D Reasoning', ICCV 2019
+ self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
+ self.raster_settings_soft = RasterizationSettings(
+ image_size=image_size, # 128
+ blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params.sigma,
+ faces_per_pixel=100) #50,
+        # rasterization settings for the body part segmentation renderer
+ self.blend_params_parts = BlendParams(sigma=2*1e-4, gamma=1e-4)
+ self.raster_settings_soft_parts = RasterizationSettings(
+ image_size=image_size, # 128
+ blur_radius=np.log(1. / 1e-4 - 1.)*self.blend_params_parts.sigma,
+ faces_per_pixel=60) #50,
+ # settings for visualization renderer
+ self.raster_settings_vis = RasterizationSettings(
+ image_size=image_size,
+ blur_radius=0.0,
+ faces_per_pixel=1)
+
+ def _get_cam(self, focal_lengths):
+ device = focal_lengths.device
+ bs = focal_lengths.shape[0]
+ if pytorch3d.__version__ == '0.2.5':
+ cameras = PerspectiveCameras(device=device,
+ focal_length=focal_lengths.repeat((1, 2)),
+ principal_point=self.principal_point.repeat((bs, 1)),
+ R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
+ image_size=self.image_size.repeat((bs, 1)))
+ elif pytorch3d.__version__ == '0.6.1':
+ cameras = PerspectiveCameras(device=device, in_ndc=False,
+ focal_length=focal_lengths.repeat((1, 2)),
+ principal_point=self.principal_point.repeat((bs, 1)),
+ R=self.R.repeat((bs, 1, 1)), T=self.T.repeat((bs, 1)),
+ image_size=self.image_size.repeat((bs, 1)))
+        else:
+            raise ValueError('this part depends on the pytorch3d version; the code was developed with 0.2.5 and also supports 0.6.1')
+ return cameras
+
+ def _get_visualization_from_mesh(self, mesh, cameras, lights=None):
+ # color renderer for visualization
+ with torch.no_grad():
+ device = mesh.device
+ # renderer for visualization
+ if lights is None:
+ lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
+ vis_renderer = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ cameras=cameras,
+ raster_settings=self.raster_settings_vis),
+ shader=HardPhongShader(
+ device=device,
+ cameras=cameras,
+ lights=lights))
+ # render image:
+ visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :]
+ return visualization
+
+
+ def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False):
+ tex = torch.ones_like(vertices) * self.mesh_color_0 # (1, V, 3)
+ textures = Textures(verts_rgb=tex)
+ mesh = Meshes(verts=vertices, faces=faces, textures=textures)
+ cameras = self._get_cam(focal_lengths)
+ # NEW: use the rasterizer to check vertex visibility
+ # see: https://github.com/facebookresearch/pytorch3d/issues/126
+ # Get a rasterizer
+ if soft:
+ rasterizer = MeshRasterizer(cameras=cameras,
+ raster_settings=self.raster_settings_soft)
+ else:
+ rasterizer = MeshRasterizer(cameras=cameras,
+ raster_settings=self.raster_settings_vis)
+ # Get the output from rasterization
+ fragments = rasterizer(mesh)
+ # pix_to_face is of shape (N, H, W, 1)
+ pix_to_face = fragments.pix_to_face
+ # (F, 3) where F is the total number of faces across all the meshes in the batch
+ packed_faces = mesh.faces_packed()
+ # (V, 3) where V is the total number of verts across all the meshes in the batch
+ packed_verts = mesh.verts_packed()
+        vertex_visibility_map = torch.zeros(packed_verts.shape[0], device=packed_verts.device)  # (V,)
+        # Indices of unique visible faces (drop the -1 entries that mark background pixels)
+        visible_faces = pix_to_face.unique()
+        visible_faces = visible_faces[visible_faces >= 0]  # (num_visible_faces,)
+ # Get Indices of unique visible verts using the vertex indices in the faces
+ visible_verts_idx = packed_faces[visible_faces] # (num_visible_faces, 3)
+ unique_visible_verts_idx = torch.unique(visible_verts_idx) # (num_visible_verts, )
+ # Update visibility indicator to 1 for all visible vertices
+ vertex_visibility_map[unique_visible_verts_idx] = 1.0
+ # since all meshes have the same amount of vertices, we can reshape the result
+ bs = vertices.shape[0]
+ vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1))
+ return pix_to_face, vertex_visibility_map_resh
+
+
+ def get_torch_meshes(self, vertices, faces, color=0):
+ # create pytorch mesh
+ if color == 0:
+ mesh_color = self.mesh_color_0
+ else:
+ mesh_color = self.mesh_color_1
+ tex = torch.ones_like(vertices) * mesh_color # (1, V, 3)
+ textures = Textures(verts_rgb=tex)
+ mesh = Meshes(verts=vertices, faces=faces, textures=textures)
+ return mesh
+
+
+ def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0):
+ # vertices: torch.Size([bs, 3889, 3])
+ # faces: torch.Size([bs, 7774, 3]), int
+ # focal_lengths: torch.Size([bs, 1])
+ device = vertices.device
+ # create cameras
+ cameras = self._get_cam(focal_lengths)
+ # create pytorch mesh
+ if color == 0:
+ mesh_color = self.mesh_color_0 # blue
+ elif color == 1:
+ mesh_color = self.mesh_color_1
+ elif color == 2:
+ MESH_COLOR_2 = [240, 250, 240] # white
+ mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
+ elif color == 3:
+ # MESH_COLOR_3 = [223, 0, 172] # pink
+ # MESH_COLOR_3 = [245, 245, 220] # beige
+ MESH_COLOR_3 = [166, 173, 164]
+ mesh_color = torch.FloatTensor(MESH_COLOR_3).to(device)
+ else:
+ MESH_COLOR_2 = [240, 250, 240]
+ mesh_color = torch.FloatTensor(MESH_COLOR_2).to(device)
+ tex = torch.ones_like(vertices) * mesh_color # (1, V, 3)
+ textures = Textures(verts_rgb=tex)
+ mesh = Meshes(verts=vertices, faces=faces, textures=textures)
+ # render mesh (no gradients)
+ # lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
+ # lights = PointLights(device=device, location=[[2.0, 2.0, -2.0]])
+ lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]])
+ visualization = self._get_visualization_from_mesh(mesh, cameras, lights=lights)
+ return visualization
+
+ def project_points(self, points, focal_lengths=None, cameras=None):
+ # points: torch.Size([bs, n_points, 3])
+ # either focal_lengths or cameras is needed:
+ # focal_lenghts: torch.Size([bs, 1])
+ # cameras: pytorch camera, for example PerspectiveCameras()
+ bs = points.shape[0]
+ device = points.device
+ screen_size = self.image_size.repeat((bs, 1))
+ if cameras is None:
+ cameras = self._get_cam(focal_lengths)
+ if pytorch3d.__version__ == '0.2.5':
+            proj_points_orig = cameras.transform_points_screen(points, screen_size)[:, :, [1, 0]]  # used in the original virtual environment (for the CVPR BARC submission)
+ elif pytorch3d.__version__ == '0.6.1':
+ proj_points_orig = cameras.transform_points_screen(points)[:, :, [1, 0]]
+        else:
+            raise ValueError('this part depends on the pytorch3d version; the code was developed with 0.2.5 and also supports 0.6.1')
+ # flip, otherwise the 1st and 2nd row are exchanged compared to the ground truth
+ proj_points = torch.flip(proj_points_orig, [2])
+ # --- project points 'manually'
+ # j_proj = project_points_p3d(image_size, focal_length, points, device)
+ return proj_points
+
+ def forward(self, vertices, points, faces, focal_lengths, color=None):
+ # vertices: torch.Size([bs, 3889, 3])
+ # points: torch.Size([bs, n_points, 3]) (or None)
+ # faces: torch.Size([bs, 7774, 3]), int
+ # focal_lengths: torch.Size([bs, 1])
+ # color: if None we don't render a visualization, else it should
+ # either be 0 or 1
+ # ---> important: results are around 0.5 pixels off compared to chumpy!
+ # have a look at renderer.py for an explanation
+ # create cameras
+ cameras = self._get_cam(focal_lengths)
+ # create pytorch mesh
+ if color is None or color == 0:
+ mesh_color = self.mesh_color_0
+ else:
+ mesh_color = self.mesh_color_1
+ tex = torch.ones_like(vertices) * mesh_color # (1, V, 3)
+ textures = Textures(verts_rgb=tex)
+ mesh = Meshes(verts=vertices, faces=faces, textures=textures)
+ # silhouette renderer
+ renderer_silh = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ cameras=cameras,
+ raster_settings=self.raster_settings_soft),
+ shader=SoftSilhouetteShader(blend_params=self.blend_params))
+ # project silhouette
+ silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1)
+ # project points
+ if points is None:
+ proj_points = None
+ else:
+ proj_points = self.project_points(points=points, cameras=cameras)
+ if color is not None:
+ # color renderer for visualization (no gradients)
+ visualization = self._get_visualization_from_mesh(mesh, cameras)
+ return silh_images, proj_points, visualization
+ else:
+ return silh_images, proj_points
+
+
+
+
diff --git a/src/smal_pytorch/smal_model/batch_lbs.py b/src/smal_pytorch/smal_model/batch_lbs.py
new file mode 100644
index 0000000000000000000000000000000000000000..98e9d321cf721ac3a47504bd49843b9979a22e71
--- /dev/null
+++ b/src/smal_pytorch/smal_model/batch_lbs.py
@@ -0,0 +1,295 @@
+'''
+Adjusted version of other PyTorch implementation of the SMAL/SMPL model
+see:
+ 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py
+ 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py
+'''
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+import numpy as np
+
+
+def batch_skew(vec, batch_size=None):
+ """
+ vec is N x 3, batch_size is int
+
+ returns N x 3 x 3. Skew_sym version of each matrix.
+ """
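+    # For r = (x, y, z) each output block is the cross-product matrix
+    #   [[ 0, -z,  y],
+    #    [ z,  0, -x],
+    #    [-y,  x,  0]]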
+ device = vec.device
+    if batch_size is None:
+        batch_size = vec.shape[0]
+ col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7])
+ indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1])
+ updates = torch.reshape(
+ torch.stack(
+ [
+ -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1],
+ vec[:, 0]
+ ],
+ dim=1), [-1])
+ out_shape = [batch_size * 9]
+ res = torch.Tensor(np.zeros(out_shape[0])).to(device=device)
+ res[np.array(indices.flatten())] = updates
+ res = torch.reshape(res, [batch_size, 3, 3])
+
+ return res
+
+
+
+def batch_rodrigues(theta):
+ """
+ Theta is Nx3
+ """
+ device = theta.device
+ batch_size = theta.shape[0]
+
+ angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1)
+ r = (torch.div(theta, angle)).unsqueeze(-1)
+
+ angle = angle.unsqueeze(-1)
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+
+ outer = torch.matmul(r, r.transpose(1,2))
+
+ eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=device)
+ H = batch_skew(r, batch_size=batch_size)
+ R = cos * eyes + (1 - cos) * outer + sin * H
+
+ return R
+
+def batch_lrotmin(theta):
+ """
+ Output of this is used to compute joint-to-pose blend shape mapping.
+ Equation 9 in SMPL paper.
+
+
+ Args:
+ pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints.
+ This includes the global rotation so K=24
+
+ Returns
+ diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted.,
+ """
+ # Ignore global rotation
+ theta = theta[:,3:]
+
+ Rs = batch_rodrigues(torch.reshape(theta, [-1,3]))
+ lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207])
+
+ return lrotmin
+
+def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False):
+ """
+ Computes absolute joint locations given pose.
+
+ rotate_base: if True, rotates the global rotation by 90 deg in x axis.
+ if False, this is the original SMPL coordinate.
+
+ Args:
+      Rs: N x K x 3 x 3, rotation matrices of K joints (K=35 for this SMAL model)
+      Js: N x K x 3, joint locations before posing
+      parent: vector of length K holding the parent id for each joint
+
+    Returns
+      new_J : `Tensor`: N x K x 3, locations of absolute joints
+      A : `Tensor`: N x K x 4 x 4, relative joint transformations for LBS.
+ """
+    device = Rs.device
+    N = Rs.shape[0]
+    if rotate_base:
+        print('Flipping the SMPL coordinate frame!!!!')
+        rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]).to(device=device)
+        rot_x = rot_x.repeat(N, 1).reshape(N, 3, 3)  # in tf this was a tile
+        root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x)
+    else:
+        root_rotation = Rs[:, 0, :, :]
+
+    # Now Js is N x K x 3 x 1
+    Js = Js.unsqueeze(-1)
+
+ def make_A(R, t):
+ # Rs is N x 3 x 3, ts is N x 3 x 1
+ R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0))
+ t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=device)], 1)
+ return torch.cat([R_homo, t_homo], 2)
+
+ A0 = make_A(root_rotation, Js[:, 0])
+ results = [A0]
+ for i in range(1, parent.shape[0]):
+ j_here = Js[:, i] - Js[:, parent[i]]
+ A_here = make_A(Rs[:, i], j_here)
+ res_here = torch.matmul(
+ results[parent[i]], A_here)
+ results.append(res_here)
+
+    # results: N x K x 4 x 4
+ results = torch.stack(results, dim=1)
+
+ new_J = results[:, :, :3, 3]
+
+    # --- Compute relative A: skinning is based on how much each bone moved,
+    # i.e. (final_bone - init_bone), not on the final bone location alone.
+    # ---
+ Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=device)], 2)
+ init_bone = torch.matmul(results, Js_w0)
+ # Append empty 4 x 3:
+ init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0))
+ A = results - init_bone
+
+ return new_J, A
+
+
+#########################################################################################
+
+def get_bone_length_scales(part_list, betas_logscale):
+ leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25))
+ tail_joints = list(range(25, 32))
+ ear_joints = [33, 34]
+ neck_joints = [15, 6] # ?
+ core_joints = [4, 5] # ?
+ mouth_joints = [16, 32]
+ log_scales = torch.zeros(betas_logscale.shape[0], 35).to(betas_logscale.device)
+ for ind, part in enumerate(part_list):
+ if part == 'legs_l':
+ log_scales[:, leg_joints] = betas_logscale[:, ind][:, None]
+ elif part == 'tail_l':
+ log_scales[:, tail_joints] = betas_logscale[:, ind][:, None]
+ elif part == 'ears_l':
+ log_scales[:, ear_joints] = betas_logscale[:, ind][:, None]
+ elif part == 'neck_l':
+ log_scales[:, neck_joints] = betas_logscale[:, ind][:, None]
+ elif part == 'core_l':
+ log_scales[:, core_joints] = betas_logscale[:, ind][:, None]
+ elif part == 'head_l':
+ log_scales[:, mouth_joints] = betas_logscale[:, ind][:, None]
+ else:
+ pass
+ all_scales = torch.exp(log_scales)
+ return all_scales[:, 1:] # don't count root
+
+def get_beta_scale_mask(part_list):
+ # which joints belong to which bodypart
+ leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25))
+ tail_joints = list(range(25, 32))
+ ear_joints = [33, 34]
+ neck_joints = [15, 6] # ?
+ core_joints = [4, 5] # ?
+ mouth_joints = [16, 32]
+ n_b_log = len(part_list) #betas_logscale.shape[1] # 8 # 6
+ beta_scale_mask = torch.zeros(35, 3, n_b_log) # .to(betas_logscale.device)
+ for ind, part in enumerate(part_list):
+ if part == 'legs_l':
+ beta_scale_mask[leg_joints, [2], [ind]] = 1.0 # Leg lengthening
+ elif part == 'legs_f':
+ beta_scale_mask[leg_joints, [0], [ind]] = 1.0 # Leg fatness
+ beta_scale_mask[leg_joints, [1], [ind]] = 1.0 # Leg fatness
+ elif part == 'tail_l':
+ beta_scale_mask[tail_joints, [0], [ind]] = 1.0 # Tail lengthening
+ elif part == 'tail_f':
+ beta_scale_mask[tail_joints, [1], [ind]] = 1.0 # Tail fatness
+ beta_scale_mask[tail_joints, [2], [ind]] = 1.0 # Tail fatness
+ elif part == 'ears_y':
+ beta_scale_mask[ear_joints, [1], [ind]] = 1.0 # Ear y
+ elif part == 'ears_l':
+ beta_scale_mask[ear_joints, [2], [ind]] = 1.0 # Ear z
+ elif part == 'neck_l':
+ beta_scale_mask[neck_joints, [0], [ind]] = 1.0 # Neck lengthening
+ elif part == 'neck_f':
+ beta_scale_mask[neck_joints, [1], [ind]] = 1.0 # Neck fatness
+ beta_scale_mask[neck_joints, [2], [ind]] = 1.0 # Neck fatness
+ elif part == 'core_l':
+ beta_scale_mask[core_joints, [0], [ind]] = 1.0 # Core lengthening
+ # beta_scale_mask[core_joints, [1], [ind]] = 1.0 # Core fatness (height)
+ elif part == 'core_fs':
+ beta_scale_mask[core_joints, [2], [ind]] = 1.0 # Core fatness (side)
+ elif part == 'head_l':
+ beta_scale_mask[mouth_joints, [0], [ind]] = 1.0 # Head lengthening
+ elif part == 'head_f':
+ beta_scale_mask[mouth_joints, [1], [ind]] = 1.0 # Head fatness 0
+ beta_scale_mask[mouth_joints, [2], [ind]] = 1.0 # Head fatness 1
+ else:
+ print(part + ' not available')
+ raise ValueError
+ beta_scale_mask = torch.transpose(
+ beta_scale_mask.reshape(35*3, n_b_log), 0, 1)
+ return beta_scale_mask
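+
+# Illustrative sketch (added as documentation, not part of the training code): how the
+# mask returned by get_beta_scale_mask() is combined with per-part log scales to obtain
+# per-joint 3x3 scaling matrices, mirroring the use in smal_model/smal_torch_new.py:
+#
+#   part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
+#   beta_scale_mask = get_beta_scale_mask(part_list)                  # (7, 35*3)
+#   betas_logscale = torch.zeros((batch_size, len(part_list)))        # predicted log scales
+#   scales = torch.exp(betas_logscale @ beta_scale_mask)              # (batch_size, 35*3)
+#   scale_factors_3x3 = torch.diag_embed(scales.reshape(-1, 35, 3))   # (batch_size, 35, 3, 3)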
+
+def batch_global_rigid_transformation_biggs(Rs, Js, parent, scale_factors_3x3, rotate_base = False, betas_logscale=None, opts=None):
+ """
+ Computes absolute joint locations given pose.
+
+ rotate_base: if True, rotates the global rotation by 90 deg in x axis.
+ if False, this is the original SMPL coordinate.
+
+    Args:
+      Rs: N x K x 3 x 3, rotation matrices of K joints (K=35 for this SMAL model)
+      Js: N x K x 3, joint locations before posing
+      parent: vector of length K holding the parent id for each joint
+      scale_factors_3x3: N x K x 3 x 3, per-joint diagonal scaling matrices
+
+    Returns
+      new_J : `Tensor`: N x K x 3, locations of absolute joints
+      A : `Tensor`: N x K x 4 x 4, relative joint transformations for LBS.
+    """
+    N = Rs.shape[0]
+    if rotate_base:
+        print('Flipping the SMPL coordinate frame!!!!')
+        rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]).to(Rs.device)
+        rot_x = rot_x.repeat(N, 1).reshape(N, 3, 3)  # in tf this was a tile
+        root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x)
+    else:
+        root_rotation = Rs[:, 0, :, :]
+
+    # Now Js is N x K x 3 x 1
+    Js = Js.unsqueeze(-1)
+
+ Js_orig = Js.clone()
+
+ def make_A(R, t):
+ # Rs is N x 3 x 3, ts is N x 3 x 1
+ R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0))
+ t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1)
+ return torch.cat([R_homo, t_homo], 2)
+
+ A0 = make_A(root_rotation, Js[:, 0])
+ results = [A0]
+ for i in range(1, parent.shape[0]):
+ j_here = Js[:, i] - Js[:, parent[i]]
+        try:
+            s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]])
+        except RuntimeError:
+            # fallback in case the parent scale matrix is (numerically) singular
+            s_par_inv = torch.max(scale_factors_3x3[:, parent[i]], 0.01 * torch.eye(3)[None, :, :].to(scale_factors_3x3.device))
+ rot = Rs[:, i]
+ s = scale_factors_3x3[:, i]
+
+ rot_new = s_par_inv @ rot @ s
+
+ A_here = make_A(rot_new, j_here)
+ res_here = torch.matmul(
+ results[parent[i]], A_here)
+
+ results.append(res_here)
+
+    # results: N x K x 4 x 4
+ results = torch.stack(results, dim=1)
+
+ # scale updates
+ new_J = results[:, :, :3, 3]
+
+    # --- Compute relative A: skinning is based on how much each bone moved,
+    # i.e. (final_bone - init_bone), not on the final bone location alone.
+    # ---
+ Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2)
+ init_bone = torch.matmul(results, Js_w0)
+ # Append empty 4 x 3:
+ init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0))
+ A = results - init_bone
+
+ return new_J, A
\ No newline at end of file
diff --git a/src/smal_pytorch/smal_model/smal_basics.py b/src/smal_pytorch/smal_model/smal_basics.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2e71ce5c5bd1d087041aed79a376eae749ad24
--- /dev/null
+++ b/src/smal_pytorch/smal_model/smal_basics.py
@@ -0,0 +1,82 @@
+'''
+Adjusted version of other PyTorch implementation of the SMAL/SMPL model
+see:
+ 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py
+ 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py
+'''
+
+import os
+import sys
+import json
+import pickle as pkl
+import numpy as np
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+from configs.SMAL_configs import SMAL_DATA_DIR, SYMMETRY_INDS_FILE
+
+# model_dir = 'smalst/smpl_models/'
+# FILE_DIR = os.path.dirname(os.path.realpath(__file__))
+model_dir = SMAL_DATA_DIR # os.path.join(FILE_DIR, '..', 'smpl_models/')
+symmetry_inds_file = SYMMETRY_INDS_FILE # os.path.join(FILE_DIR, '..', 'smpl_models/symmetry_inds.json')
+with open(symmetry_inds_file) as f:
+ symmetry_inds_dict = json.load(f)
+LEFT_INDS = np.asarray(symmetry_inds_dict['left_inds'])
+RIGHT_INDS = np.asarray(symmetry_inds_dict['right_inds'])
+CENTER_INDS = np.asarray(symmetry_inds_dict['center_inds'])
+
+
+def get_symmetry_indices():
+ sym_dict = {'left': LEFT_INDS,
+ 'right': RIGHT_INDS,
+ 'center': CENTER_INDS}
+ return sym_dict
+
+def verify_symmetry(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS):
+ # shapedirs: (3889, 3, n_sh)
+ assert (shapedirs[center_inds, 1, :] == 0.0).all()
+ assert (shapedirs[right_inds, 1, :] == -shapedirs[left_inds, 1, :]).all()
+ return
+
+def from_shapedirs_to_shapedirs_half(shapedirs, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS, verify=False):
+ # shapedirs: (3889, 3, n_sh)
+ # shapedirs_half: (2012, 3, n_sh)
+ selected_inds = np.concatenate((center_inds, left_inds), axis=0)
+ shapedirs_half = shapedirs[selected_inds, :, :]
+ if verify:
+ verify_symmetry(shapedirs)
+ else:
+ shapedirs_half[:center_inds.shape[0], 1, :] = 0.0
+ return shapedirs_half
+
+def from_shapedirs_half_to_shapedirs(shapedirs_half, center_inds=CENTER_INDS, left_inds=LEFT_INDS, right_inds=RIGHT_INDS):
+ # shapedirs_half: (2012, 3, n_sh)
+ # shapedirs: (3889, 3, n_sh)
+ shapedirs = np.zeros((center_inds.shape[0] + 2*left_inds.shape[0], 3, shapedirs_half.shape[2]))
+ shapedirs[center_inds, :, :] = shapedirs_half[:center_inds.shape[0], :, :]
+ shapedirs[left_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :]
+ shapedirs[right_inds, :, :] = shapedirs_half[center_inds.shape[0]:, :, :]
+ shapedirs[right_inds, 1, :] = - shapedirs_half[center_inds.shape[0]:, 1, :]
+ return shapedirs
+
+def align_smal_template_to_symmetry_axis(v, subtract_mean=True):
+ # These are the indexes of the points that are on the symmetry axis
+ I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003]
+ if subtract_mean:
+ v = v - np.mean(v)
+ y = np.mean(v[I,1])
+ v[:,1] = v[:,1] - y
+ v[I,1] = 0
+ left_inds = LEFT_INDS
+ right_inds = RIGHT_INDS
+ center_inds = CENTER_INDS
+ v[right_inds, :] = np.array([1,-1,1])*v[left_inds, :]
+    assert len(left_inds) == len(right_inds)
+ return v, left_inds, right_inds, center_inds
+
+
+
diff --git a/src/smal_pytorch/smal_model/smal_torch_new.py b/src/smal_pytorch/smal_model/smal_torch_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..5562a33b97849116d827a5213e81c40ece705b70
--- /dev/null
+++ b/src/smal_pytorch/smal_model/smal_torch_new.py
@@ -0,0 +1,313 @@
+"""
+PyTorch implementation of the SMAL/SMPL model
+see:
+ 1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py
+ 2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py
+main changes compared to SMALST and WLDO:
+    * new model
+      (/ps/scratch/nrueegg/new_projects/side_packages/SMALify/new_smal_pca/results/my_tposeref_results_3/)
+      dogs are part of the PCA used to create the model,
+      all meshes are centered around their root joint,
+      the animals are all scaled such that their body length (butt to breast) is 1
+ X_init = np.concatenate((vertices_dogs, vertices_smal), axis=0) # vertices_dogs
+ X = []
+ for ind in range(0, X_init.shape[0]):
+ X_tmp, _, _, _ = align_smal_template_to_symmetry_axis(X_init[ind, :, :], subtract_mean=True) # not sure if this is necessary
+ X.append(X_tmp)
+ X = np.asarray(X)
+ # define points which will be used for normalization
+ idxs_front = [6, 16, 8, 964] # [1172, 6, 16, 8, 964]
+ idxs_back = [174, 2148, 175, 2149] # not in the middle, but pairs
+ reg_j = np.asarray(dd['J_regressor'].todense())
+ # normalize the meshes such that X_frontback_dist is 1 and the root joint is in the center (0, 0, 0)
+ X_front = X[:, idxs_front, :].mean(axis=1)
+ X_back = X[:, idxs_back, :].mean(axis=1)
+ X_frontback_dist = np.sqrt(((X_front - X_back)**2).sum(axis=1))
+ X = X / X_frontback_dist[:, None, None]
+ X_j0 = np.sum(X[:, reg_j[0, :]>0, :] * reg_j[0, (reg_j[0, :]>0)][None, :, None], axis=1)
+ X = X - X_j0[:, None, :]
+ * add limb length changes the same way as in WLDO
+ * overall scale factor is added
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import torch
+import chumpy as ch
+import os.path
+from torch import nn
+from torch.autograd import Variable
+import pickle as pkl
+from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation, batch_global_rigid_transformation_biggs, get_bone_length_scales, get_beta_scale_mask
+
+from .smal_basics import align_smal_template_to_symmetry_axis, get_symmetry_indices
+
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+from configs.SMAL_configs import KEY_VIDS, CANONICAL_MODEL_JOINTS, IDXS_BONES_NO_REDUNDANCY, SMAL_MODEL_PATH
+
+from smal_pytorch.utils import load_vertex_colors
+
+
+# There are chumpy variables so convert them to numpy.
+def undo_chumpy(x):
+ return x if isinstance(x, np.ndarray) else x.r
+
+# class SMAL(object):
+class SMAL(nn.Module):
+ def __init__(self, pkl_path=SMAL_MODEL_PATH, n_betas=None, template_name='neutral', use_smal_betas=True, logscale_part_list=None):
+ super(SMAL, self).__init__()
+
+        if logscale_part_list is None:
+            logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
+        self.logscale_part_list = logscale_part_list
+        self.betas_scale_mask = get_beta_scale_mask(part_list=self.logscale_part_list)
+        self.num_betas_logscale = len(self.logscale_part_list)
+
+ self.use_smal_betas = use_smal_betas
+
+ # -- Load SMPL params --
+ try:
+ with open(pkl_path, 'r') as f:
+ dd = pkl.load(f)
+ except (UnicodeDecodeError, TypeError) as e:
+ with open(pkl_path, 'rb') as file:
+ u = pkl._Unpickler(file)
+ u.encoding = 'latin1'
+ dd = u.load()
+
+ self.f = dd['f']
+ self.register_buffer('faces', torch.from_numpy(self.f.astype(int)))
+
+ # get the correct template (mean shape)
+ if template_name=='neutral':
+ v_template = dd['v_template']
+ v = v_template
+ else:
+ raise NotImplementedError
+
+ # Mean template vertices
+ self.register_buffer('v_template', torch.Tensor(v))
+ # Size of mesh [Number of vertices, 3]
+ self.size = [self.v_template.shape[0], 3]
+ self.num_betas = dd['shapedirs'].shape[-1]
+ # symmetry indices
+ self.sym_ids_dict = get_symmetry_indices()
+
+ # Shape blend shape basis
+ shapedir = np.reshape(undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T
+ shapedir.flags['WRITEABLE'] = True # not sure why this is necessary
+ self.register_buffer('shapedirs', torch.Tensor(shapedir))
+
+ # Regressor for joint locations given shape
+ self.register_buffer('J_regressor', torch.Tensor(dd['J_regressor'].T.todense()))
+
+ # Pose blend shape basis
+ num_pose_basis = dd['posedirs'].shape[-1]
+
+ posedirs = np.reshape(undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T
+ self.register_buffer('posedirs', torch.Tensor(posedirs))
+
+ # indices of parents for each joints
+ self.parents = dd['kintree_table'][0].astype(np.int32)
+
+ # LBS weights
+ self.register_buffer('weights', torch.Tensor(undo_chumpy(dd['weights'])))
+
+
+ def _caclulate_bone_lengths_from_J(self, J, betas_logscale):
+ # NEW: calculate bone lengths:
+ all_bone_lengths_list = []
+ for i in range(1, self.parents.shape[0]):
+ bone_vec = J[:, i] - J[:, self.parents[i]]
+ bone_length = torch.sqrt(torch.sum(bone_vec ** 2, axis=1))
+ all_bone_lengths_list.append(bone_length)
+ all_bone_lengths = torch.stack(all_bone_lengths_list)
+ # some bones are pairs, it is enough to take one of the two bones
+ all_bone_length_scales = get_bone_length_scales(self.logscale_part_list, betas_logscale)
+ all_bone_lengths = all_bone_lengths.permute((1,0)) * all_bone_length_scales
+
+ return all_bone_lengths #.permute((1,0))
+
+
+ def caclulate_bone_lengths(self, beta, betas_logscale, shapedirs_sel=None, short=True):
+ nBetas = beta.shape[1]
+
+ # 1. Add shape blend shapes
+ # do we use the original shapedirs or a new set of selected shapedirs?
+ if shapedirs_sel is None:
+ shapedirs_sel = self.shapedirs[:nBetas,:]
+ else:
+ assert shapedirs_sel.shape[0] == nBetas
+ v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]])
+
+ # 2. Infer shape-dependent joint locations.
+ Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor)
+ Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor)
+ Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor)
+ J = torch.stack([Jx, Jy, Jz], dim=2)
+
+ # calculate bone lengths
+ all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale)
+ selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY]
+
+ if short:
+ return selected_bone_lengths
+ else:
+ return all_bone_lengths
+
+
+
+ def __call__(self, beta, betas_limbs, theta=None, pose=None, trans=None, del_v=None, get_skin=True, keyp_conf='red', get_all_info=False, shapedirs_sel=None):
+ device = beta.device
+
+ betas_logscale = betas_limbs
+ # NEW: allow that rotation is given as rotation matrices instead of axis angle rotation
+ # theta: BSxNJointsx3 or BSx(NJoints*3)
+ # pose: NxNJointsx3x3
+ if (theta is None) and (pose is None):
+ raise ValueError("Either pose (rotation matrices NxNJointsx3x3) or theta (axis angle BSxNJointsx3) must be given")
+        elif (theta is not None) and (pose is not None):
+            raise ValueError("pose (rotation matrices NxNJointsx3x3) and theta (axis angle BSxNJointsx3) cannot both be given")
+
+ if True: # self.use_smal_betas:
+ nBetas = beta.shape[1]
+ else:
+ nBetas = 0
+
+ # 1. Add shape blend shapes
+ # do we use the original shapedirs or a new set of selected shapedirs?
+ if shapedirs_sel is None:
+ shapedirs_sel = self.shapedirs[:nBetas,:]
+ else:
+ assert shapedirs_sel.shape[0] == nBetas
+
+ if nBetas > 0:
+ if del_v is None:
+ v_shaped = self.v_template + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]])
+ else:
+ v_shaped = self.v_template + del_v + torch.reshape(torch.matmul(beta, shapedirs_sel), [-1, self.size[0], self.size[1]])
+ else:
+ if del_v is None:
+ v_shaped = self.v_template.unsqueeze(0)
+ else:
+ v_shaped = self.v_template + del_v
+
+ # 2. Infer shape-dependent joint locations.
+ Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor)
+ Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor)
+ Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor)
+ J = torch.stack([Jx, Jy, Jz], dim=2)
+
+ # 3. Add pose blend shapes
+ # N x 24 x 3 x 3
+ if pose is None:
+ Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3])
+ else:
+ Rs = pose
+ # Ignore global rotation.
+ pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(device=device), [-1, 306])
+
+ v_posed = torch.reshape(
+ torch.matmul(pose_feature, self.posedirs),
+ [-1, self.size[0], self.size[1]]) + v_shaped
+
+ #-------------------------
+ # new: add corrections of bone lengths to the template (before hypothetical pose blend shapes!)
+ # see biggs batch_lbs.py
+ betas_scale = torch.exp(betas_logscale @ self.betas_scale_mask.to(betas_logscale.device))
+ scaling_factors = betas_scale.reshape(-1, 35, 3)
+ scale_factors_3x3 = torch.diag_embed(scaling_factors, dim1=-2, dim2=-1)
+
+ # 4. Get the global joint location
+ # self.J_transformed, A = batch_global_rigid_transformation(Rs, J, self.parents)
+ self.J_transformed, A = batch_global_rigid_transformation_biggs(Rs, J, self.parents, scale_factors_3x3, betas_logscale=betas_logscale)
+
+ # 2-BONES. Calculate bone lengths
+ all_bone_lengths = self._caclulate_bone_lengths_from_J(J, betas_logscale)
+ # selected_bone_lengths = all_bone_lengths[:, IDXS_BONES_NO_REDUNDANCY]
+ #-------------------------
+
+ # 5. Do skinning:
+ num_batch = Rs.shape[0]
+
+ weights_t = self.weights.repeat([num_batch, 1])
+ W = torch.reshape(weights_t, [num_batch, -1, 35])
+
+
+ T = torch.reshape(
+ torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])),
+ [num_batch, -1, 4, 4])
+ v_posed_homo = torch.cat(
+ [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=device)], 2)
+ v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1))
+
+ verts = v_homo[:, :, :3, 0]
+
+ if trans is None:
+ trans = torch.zeros((num_batch,3)).to(device=device)
+
+ verts = verts + trans[:,None,:]
+
+ # Get joints:
+ joint_x = torch.matmul(verts[:, :, 0], self.J_regressor)
+ joint_y = torch.matmul(verts[:, :, 1], self.J_regressor)
+ joint_z = torch.matmul(verts[:, :, 2], self.J_regressor)
+ joints = torch.stack([joint_x, joint_y, joint_z], dim=2)
+
+ # New... (see https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py)
+ joints = torch.cat([
+ joints,
+ verts[:, None, 1863], # end_of_nose
+ verts[:, None, 26], # chin
+ verts[:, None, 2124], # right ear tip
+ verts[:, None, 150], # left ear tip
+ verts[:, None, 3055], # left eye
+ verts[:, None, 1097], # right eye
+ ], dim = 1)
+
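+        # Keypoint sets that can be selected via keyp_conf:
+        #   'red':   the regressed SMAL joints only (without the 6 appended vertex keypoints)
+        #   'green': joints selected/reordered via CANONICAL_MODEL_JOINTS
+        #   'blue':  keypoints obtained by averaging the vertex groups defined in KEY_VIDS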
+ if keyp_conf == 'blue' or keyp_conf == 'dict':
+ # Generate keypoints
+ nLandmarks = KEY_VIDS.shape[0] # 24
+ j3d = torch.zeros((num_batch, nLandmarks, 3)).to(device=device)
+ for j in range(nLandmarks):
+ j3d[:, j,:] = torch.mean(verts[:, KEY_VIDS[j],:], dim=1) # translation is already added to the vertices
+ joints_blue = j3d
+
+ joints_red = joints[:, :-6, :]
+ joints_green = joints[:, CANONICAL_MODEL_JOINTS, :]
+
+ if keyp_conf == 'red':
+ relevant_joints = joints_red
+ elif keyp_conf == 'green':
+ relevant_joints = joints_green
+ elif keyp_conf == 'blue':
+ relevant_joints = joints_blue
+ elif keyp_conf == 'dict':
+ relevant_joints = {'red': joints_red,
+ 'green': joints_green,
+ 'blue': joints_blue}
+ else:
+ raise NotImplementedError
+
+ if get_all_info:
+ return verts, relevant_joints, Rs, all_bone_lengths
+ else:
+ if get_skin:
+ return verts, relevant_joints, Rs # , v_shaped
+ else:
+ return relevant_joints
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/smal_pytorch/utils.py b/src/smal_pytorch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e48a0fe88cf27472c56cb7c9d3359984fd9b9a
--- /dev/null
+++ b/src/smal_pytorch/utils.py
@@ -0,0 +1,13 @@
+import numpy as np
+
+def load_vertex_colors(obj_path):
+ v_colors = []
+ for line in open(obj_path, "r"):
+ if line.startswith('#'): continue
+ values = line.split()
+ if not values: continue
+ if values[0] == 'v':
+ v_colors.append(values[4:7])
+ else:
+ continue
+ return np.asarray(v_colors, dtype=np.float32)
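+
+
+# Usage sketch (illustrative; the path is a placeholder): expects an OBJ file that stores
+# per-vertex colors as 'v x y z r g b' lines and returns an array of shape (n_vertices, 3).
+#
+#   colors = load_vertex_colors('<path_to_obj_with_vertex_colors>.obj')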
diff --git a/src/stacked_hourglass/__init__.py b/src/stacked_hourglass/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5da308c10ea77f41570a7a0d417c28ad19ae9d2
--- /dev/null
+++ b/src/stacked_hourglass/__init__.py
@@ -0,0 +1,2 @@
+from stacked_hourglass.model import hg1, hg2, hg4, hg8
+from stacked_hourglass.predictor import HumanPosePredictor
diff --git a/src/stacked_hourglass/datasets/__init__.py b/src/stacked_hourglass/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/stacked_hourglass/datasets/imgcrops.py b/src/stacked_hourglass/datasets/imgcrops.py
new file mode 100644
index 0000000000000000000000000000000000000000..89face653c8d6c92fb4bf453a1ae46957ee68dff
--- /dev/null
+++ b/src/stacked_hourglass/datasets/imgcrops.py
@@ -0,0 +1,77 @@
+
+
+import os
+import glob
+import numpy as np
+import torch
+import torch.utils.data as data
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+from configs.anipose_data_info import COMPLETE_DATA_INFO
+from stacked_hourglass.utils.imutils import load_image
+from stacked_hourglass.utils.transforms import crop, color_normalize
+from stacked_hourglass.utils.pilutil import imresize
+from stacked_hourglass.utils.imutils import im_to_torch
+from configs.dataset_path_configs import TEST_IMAGE_CROP_ROOT_DIR
+from configs.data_info import COMPLETE_DATA_INFO_24
+
+
+class ImgCrops(data.Dataset):
+ DATA_INFO = COMPLETE_DATA_INFO_24
+ ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]
+
+ def __init__(self, img_crop_folder='default', image_path=None, is_train=False, inp_res=256, out_res=64, sigma=1,
+ scale_factor=0.25, rot_factor=30, label_type='Gaussian',
+ do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only'):
+ assert is_train == False
+ assert do_augment == 'default' or do_augment == False
+ self.inp_res = inp_res
+ if img_crop_folder == 'default':
+ self.folder_imgs = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'datasets', 'test_image_crops')
+ else:
+ self.folder_imgs = img_crop_folder
+ name_list = glob.glob(os.path.join(self.folder_imgs, '*.png')) + glob.glob(os.path.join(self.folder_imgs, '*.jpg')) + glob.glob(os.path.join(self.folder_imgs, '*.jpeg'))
+ name_list = sorted(name_list)
+ self.test_name_list = [name.split('/')[-1] for name in name_list]
+ print('len(dataset): ' + str(self.__len__()))
+
+ def __getitem__(self, index):
+ img_name = self.test_name_list[index]
+ # load image
+ img_path = os.path.join(self.folder_imgs, img_name)
+ img = load_image(img_path) # CxHxW
+ # prepare image (cropping and color)
+ img_max = max(img.shape[1], img.shape[2])
+ img_padded = torch.zeros((img.shape[0], img_max, img_max))
+ if img_max == img.shape[2]:
+ start = (img_max-img.shape[1])//2
+ img_padded[:, start:start+img.shape[1], :] = img
+ else:
+ start = (img_max-img.shape[2])//2
+ img_padded[:, :, start:start+img.shape[2]] = img
+ img = img_padded
+ img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear'))
+ inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
+ # add the following fields to make it compatible with stanext, most of them are fake
+ target_dict = {'index': index, 'center' : -2, 'scale' : -2,
+ 'breed_index': -2, 'sim_breed_index': -2,
+ 'ind_dataset': 1}
+ target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
+ target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
+ target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1))
+ target_dict['silh'] = np.zeros((self.inp_res, self.inp_res))
+ return inp, target_dict
+
+
+ def __len__(self):
+ return len(self.test_name_list)
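+
+
+# Usage sketch (added for illustration): iterate over image crops in the default test
+# folder. This assumes the folder datasets/test_image_crops contains .png/.jpg/.jpeg files.
+if __name__ == "__main__":
+    dataset = ImgCrops(img_crop_folder='default')
+    inp, target_dict = dataset[0]   # inp: normalized image tensor of shape (3, 256, 256)
+    print(inp.shape, target_dict['index'])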
+
+
+
+
+
+
+
+
+
diff --git a/src/stacked_hourglass/datasets/imgcropslist.py b/src/stacked_hourglass/datasets/imgcropslist.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c87dbb902995cf02393247f990217dbd2746f7
--- /dev/null
+++ b/src/stacked_hourglass/datasets/imgcropslist.py
@@ -0,0 +1,95 @@
+
+
+import os
+import glob
+import numpy as np
+import math
+import torch
+import torch.utils.data as data
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+from configs.anipose_data_info import COMPLETE_DATA_INFO
+from stacked_hourglass.utils.imutils import load_image, im_to_torch
+from stacked_hourglass.utils.transforms import crop, color_normalize
+from stacked_hourglass.utils.pilutil import imresize
+from stacked_hourglass.utils.imutils import im_to_torch
+from configs.data_info import COMPLETE_DATA_INFO_24
+
+
+class ImgCrops(data.Dataset):
+ DATA_INFO = COMPLETE_DATA_INFO_24
+ ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]
+
+ def __init__(self, image_list, bbox_list=None, inp_res=256, dataset_mode='keyp_only'):
+ # the list contains the images directly, not only their paths
+ self.image_list = image_list
+ self.bbox_list = bbox_list
+ self.inp_res = inp_res
+ self.test_name_list = []
+ for ind in np.arange(0, len(self.image_list)):
+ self.test_name_list.append(str(ind))
+ print('len(dataset): ' + str(self.__len__()))
+
+    def __getitem__(self, index):
+        # load image (the list holds the images themselves, not paths)
+        img = im_to_torch(self.image_list[index])
+
+ # try loading bounding box
+ if self.bbox_list is not None:
+ bbox = self.bbox_list[index]
+ bbox_xywh = [bbox[0][0], bbox[0][1], bbox[1][0]-bbox[0][0], bbox[1][1]-bbox[0][1]]
+ bbox_c = [bbox_xywh[0]+0.5*bbox_xywh[2], bbox_xywh[1]+0.5*bbox_xywh[3]]
+ bbox_max = max(bbox_xywh[2], bbox_xywh[3])
+ bbox_diag = math.sqrt(bbox_xywh[2]**2 + bbox_xywh[3]**2)
+            bbox_s = bbox_max / 200. * 256. / 200.  # scale in the hourglass convention (crop side = 200*s); the factor 256/200 adds padding around the bbox
+ c = torch.Tensor(bbox_c)
+ s = bbox_s
+ img_prep = crop(img, c, s, [self.inp_res, self.inp_res], rot=0)
+
+ else:
+
+ # prepare image (cropping and color)
+ img_max = max(img.shape[1], img.shape[2])
+ img_padded = torch.zeros((img.shape[0], img_max, img_max))
+ if img_max == img.shape[2]:
+ start = (img_max-img.shape[1])//2
+ img_padded[:, start:start+img.shape[1], :] = img
+ else:
+ start = (img_max-img.shape[2])//2
+ img_padded[:, :, start:start+img.shape[2]] = img
+ img = img_padded
+ img_prep = im_to_torch(imresize(img, [self.inp_res, self.inp_res], interp='bilinear'))
+
+ inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
+ # add the following fields to make it compatible with stanext, most of them are fake
+ target_dict = {'index': index, 'center' : -2, 'scale' : -2,
+ 'breed_index': -2, 'sim_breed_index': -2,
+ 'ind_dataset': 1}
+ target_dict['pts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
+ target_dict['tpts'] = np.zeros((self.DATA_INFO.n_keyp, 3))
+ target_dict['target_weight'] = np.zeros((self.DATA_INFO.n_keyp, 1))
+ target_dict['silh'] = np.zeros((self.inp_res, self.inp_res))
+ return inp, target_dict
+
+
+ def __len__(self):
+ return len(self.image_list)
+
+
+
+
+
+
+
+
+
diff --git a/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb8a636d1138a58cd2265f931e2c19ef47a9220
--- /dev/null
+++ b/src/stacked_hourglass/datasets/samplers/custom_pair_samplers.py
@@ -0,0 +1,171 @@
+
+import numpy as np
+import random
+import copy
+import time
+import warnings
+
+from torch.utils.data import Sampler
+try:
+    from torch._six import int_classes as _int_classes
+except ImportError:
+    # torch._six was removed in newer PyTorch versions
+    _int_classes = int
+
+class CustomPairBatchSampler(Sampler):
+    """Yields mini-batches of indices in which the dogs are grouped into pairs of the same breed.
+    The structure of this sampler is more complicated than necessary because it is a shortened/simplified
+    version of CustomBatchSampler. The relations between breeds are not relevant for the CVPR 2022 paper,
+    but we kept the structure that we were using for the experiments with clade-related losses.
+    ToDo: restructure this sampler.
+    Args:
+        data_sampler_info (dict): a dictionary containing information about the dataset and breeds.
+        batch_size (int): Size of the mini-batch (must be even).
+ """
+
+ def __init__(self, data_sampler_info, batch_size):
+ if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
+ batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, "
+ "but got batch_size={}".format(batch_size))
+ assert batch_size%2 == 0
+ self.data_sampler_info = data_sampler_info
+ self.batch_size = batch_size
+ self.n_desired_batches = int(np.floor(len(self.data_sampler_info['name_list']) / batch_size)) # 157
+
+ def get_description(self):
+ description = "\
+ This sampler works only for even batch sizes. \n\
+ It returns pairs of dogs of the same breed"
+ return description
+
+
+ def __iter__(self):
+ breeds_summary = self.data_sampler_info['breeds_summary']
+
+ breed_image_dict_orig = {}
+ for img_name in self.data_sampler_info['name_list']: # ['n02093859-Kerry_blue_terrier/n02093859_913.jpg', ... ]
+ folder_name = img_name.split('/')[0]
+ breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1]
+ if not (breed_name in breed_image_dict_orig):
+ breed_image_dict_orig[breed_name] = [img_name]
+ else:
+ breed_image_dict_orig[breed_name].append(img_name)
+
+ lengths = np.zeros((len(breed_image_dict_orig.values())))
+ for ind, value in enumerate(breed_image_dict_orig.values()):
+ lengths[ind] = len(value)
+
+        sim_matrix_raw = self.data_sampler_info['breeds_sim_martix_raw']  # about 1061 nonzero connections
+
+ # from ind_in_sim_mat to breed_name
+ inverse_sim_dict = {}
+ for abbrev, ind in self.data_sampler_info['breeds_sim_abbrev_inds'].items():
+ # breed_name might be None
+ breed = breeds_summary[abbrev]
+ breed_name = breed._name_stanext
+ inverse_sim_dict[ind] = {'abbrev': abbrev,
+ 'breed_name': breed_name}
+
+ # similarity for relevant breeds only:
+ related_breeds_top_orig = {}
+ temp = np.arange(sim_matrix_raw.shape[0])
+ for breed_name, breed_images in breed_image_dict_orig.items():
+ abbrev = self.data_sampler_info['breeds_abbrev_dict'][breed_name]
+ related_breeds = {}
+ if abbrev in self.data_sampler_info['breeds_sim_abbrev_inds'].keys():
+ ind_in_sim_mat = self.data_sampler_info['breeds_sim_abbrev_inds'][abbrev]
+ row = sim_matrix_raw[ind_in_sim_mat, :]
+ rel_inds = temp[row>0]
+ for ind in rel_inds:
+ rel_breed_name = inverse_sim_dict[ind]['breed_name']
+ rel_abbrev = inverse_sim_dict[ind]['abbrev']
+ # does this breed exist in this dataset?
+ if (rel_breed_name is not None) and (rel_breed_name in breed_image_dict_orig.keys()) and not (rel_breed_name==breed_name):
+ related_breeds[rel_breed_name] = row[ind]
+ related_breeds_top_orig[breed_name] = related_breeds
+
+ breed_image_dict = copy.deepcopy(breed_image_dict_orig)
+ related_breeds_top = copy.deepcopy(related_breeds_top_orig)
+
+ # clean the related_breeds_top dict such that it only contains breeds which are available
+ for breed_name, breed_images in breed_image_dict.items():
+ if len(breed_image_dict[breed_name]) < 1:
+ for breed_name_rel in list(related_breeds_top[breed_name].keys()):
+ related_breeds_top[breed_name_rel].pop(breed_name, None)
+ related_breeds_top[breed_name].pop(breed_name_rel, None)
+
+ # 1) build pairs of dogs
+ set_of_breeds_with_at_least_2 = set()
+ for breed_name, breed_images in breed_image_dict.items():
+ if len(breed_images) >= 2:
+ set_of_breeds_with_at_least_2.add(breed_name)
+
+ n_unused_images = len(self.data_sampler_info['name_list'])
+ all_dog_duos = []
+ n_new_duos = 1
+ while n_new_duos > 0:
+ for breed_name, breed_images in breed_image_dict.items():
+ # shuffle image list for this specific breed (this changes the dict)
+ random.shuffle(breed_images)
+ breed_list = list(related_breeds_top.keys())
+ random.shuffle(breed_list)
+ n_new_duos = 0
+ for breed_name in breed_list:
+ if len(breed_image_dict[breed_name]) >= 2:
+ dog_a = breed_image_dict[breed_name].pop()
+ dog_b = breed_image_dict[breed_name].pop()
+ dog_duo = [dog_a, dog_b]
+ all_dog_duos.append({'image_names': dog_duo})
+ # clean the related_breeds_top dict such that it only contains breeds which are still available
+ if len(breed_image_dict[breed_name]) < 1:
+ for breed_name_rel in list(related_breeds_top[breed_name].keys()):
+ related_breeds_top[breed_name_rel].pop(breed_name, None)
+ related_breeds_top[breed_name].pop(breed_name_rel, None)
+ n_new_duos += 1
+ n_unused_images -= 2
+
+ image_name_to_ind = {}
+ for ind_img_name, img_name in enumerate(self.data_sampler_info['name_list']):
+ image_name_to_ind[img_name] = ind_img_name
+
+ # take all images and create the batches
+ n_avail_2 = len(all_dog_duos)
+ all_batches = []
+ ind_in_duos = 0
+ n_imgs_used_twice = 0
+ for ind_b in range(0, self.n_desired_batches):
+ batch_with_image_names = []
+ for ind in range(int(np.floor(self.batch_size / 2))):
+ if ind_in_duos >= n_avail_2:
+ ind_rand = random.randint(0, n_avail_2-1)
+ batch_with_image_names.extend(all_dog_duos[ind_rand]['image_names'])
+ n_imgs_used_twice += 2
+ else:
+ batch_with_image_names.extend(all_dog_duos[ind_in_duos]['image_names'])
+ ind_in_duos += 1
+
+
+ batch_with_inds = []
+ for image_name in batch_with_image_names: # rather a folder than name
+ batch_with_inds.append(image_name_to_ind[image_name])
+
+ all_batches.append(batch_with_inds)
+
+ for batch in all_batches:
+ yield batch
+
+ def __len__(self):
+        # Since we sample pairs of dogs and not every breed has an even number of dogs, we cannot
+        # guarantee that each dog is shown exactly once. Instead, we return the same number of
+        # batches as a standard (non-pair-based) sampler would.
+ return self.n_desired_batches
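+
+
+# Usage sketch (added for illustration): the sampler is passed to a DataLoader as
+# batch_sampler; data_sampler_info is expected to have the structure produced by
+# StanExt.get_data_sampler_info() (see stacked_hourglass/datasets/stanext24.py).
+#
+#   dataset = StanExt(is_train=True, dataset_mode='complete')
+#   sampler = CustomPairBatchSampler(dataset.get_data_sampler_info(), batch_size=12)
+#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)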
+
+
+
+
+
+
+
+
diff --git a/src/stacked_hourglass/datasets/stanext24.py b/src/stacked_hourglass/datasets/stanext24.py
new file mode 100644
index 0000000000000000000000000000000000000000..e217bf076fb63de5655fc173737ecd2e9803b1e6
--- /dev/null
+++ b/src/stacked_hourglass/datasets/stanext24.py
@@ -0,0 +1,301 @@
+# 24 joints instead of 20!!
+
+
+import gzip
+import json
+import os
+import random
+import math
+import numpy as np
+import torch
+import torch.utils.data as data
+from importlib_resources import open_binary
+from scipy.io import loadmat
+from tabulate import tabulate
+import itertools
+from scipy import ndimage
+
+from csv import DictReader
+from pycocotools.mask import decode as decode_RLE
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+from configs.data_info import COMPLETE_DATA_INFO_24
+from stacked_hourglass.utils.imutils import load_image, draw_labelmap, draw_multiple_labelmaps
+from stacked_hourglass.utils.misc import to_torch
+from stacked_hourglass.utils.transforms import shufflelr, crop, color_normalize, fliplr, transform
+import stacked_hourglass.datasets.utils_stanext as utils_stanext
+from stacked_hourglass.utils.visualization import save_input_image_with_keypoints
+from configs.dog_breeds.dog_breed_class import COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS, SIM_MATRIX_RAW, SIM_ABBREV_INDICES
+from configs.dataset_path_configs import STANEXT_RELATED_DATA_ROOT_DIR
+
+
+class StanExt(data.Dataset):
+ DATA_INFO = COMPLETE_DATA_INFO_24
+
+ # Suggested joints to use for keypoint reprojection error calculations
+ ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]
+
+ def __init__(self, image_path=None, is_train=True, inp_res=256, out_res=64, sigma=1,
+ scale_factor=0.25, rot_factor=30, label_type='Gaussian',
+ do_augment='default', shorten_dataset_to=None, dataset_mode='keyp_only', V12=None, val_opt='test'):
+ self.V12 = V12
+ self.is_train = is_train # training set or test set
+ if do_augment == 'yes':
+ self.do_augment = True
+ elif do_augment == 'no':
+ self.do_augment = False
+ elif do_augment=='default':
+ if self.is_train:
+ self.do_augment = True
+ else:
+ self.do_augment = False
+ else:
+ raise ValueError
+ self.inp_res = inp_res
+ self.out_res = out_res
+ self.sigma = sigma
+ self.scale_factor = scale_factor
+ self.rot_factor = rot_factor
+ self.label_type = label_type
+ self.dataset_mode = dataset_mode
+ if self.dataset_mode=='complete' or self.dataset_mode=='keyp_and_seg' or self.dataset_mode=='keyp_and_seg_and_partseg':
+ self.calc_seg = True
+ else:
+ self.calc_seg = False
+ self.val_opt = val_opt
+
+ # create train/val split
+ self.img_folder = utils_stanext.get_img_dir(V12=self.V12)
+ self.train_dict, init_test_dict, init_val_dict = utils_stanext.load_stanext_json_as_dict(split_train_test=True, V12=self.V12)
+ self.train_name_list = list(self.train_dict.keys()) # 7004
+ if self.val_opt == 'test':
+ self.test_dict = init_test_dict
+ self.test_name_list = list(self.test_dict.keys())
+ elif self.val_opt == 'val':
+ self.test_dict = init_val_dict
+ self.test_name_list = list(self.test_dict.keys())
+ else:
+ raise NotImplementedError
+
+ # stanext breed dict (contains for each name a stanext specific index)
+ breed_json_path = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'StanExt_breed_dict_v2.json')
+ self.breed_dict = self.get_breed_dict(breed_json_path, create_new_breed_json=False)
+ self.train_name_list = sorted(self.train_name_list)
+ self.test_name_list = sorted(self.test_name_list)
+ random.seed(4)
+ random.shuffle(self.train_name_list)
+ random.shuffle(self.test_name_list)
+ if shorten_dataset_to is not None:
+ # sometimes it is useful to have a smaller set (validation speed, debugging)
+ self.train_name_list = self.train_name_list[0 : min(len(self.train_name_list), shorten_dataset_to)]
+ self.test_name_list = self.test_name_list[0 : min(len(self.test_name_list), shorten_dataset_to)]
+ # special case for debugging: use 12 copies of the same test image
+ if shorten_dataset_to == 12:
+ my_sample = self.test_name_list[2]
+ for ind in range(0, 12):
+ self.test_name_list[ind] = my_sample
+ print('len(dataset): ' + str(self.__len__()))
+
+ # add results for eyes, withers and throat as obtained through anipose -> they are used
+ # as pseudo ground truth at training time.
+ self.path_anipose_out_root = os.path.join(STANEXT_RELATED_DATA_ROOT_DIR, 'animalpose_hg8_v0_results_on_StanExt')
+
+
+ def get_data_sampler_info(self):
+ # for custom data sampler
+ if self.is_train:
+ name_list = self.train_name_list
+ else:
+ name_list = self.test_name_list
+ info_dict = {'name_list': name_list,
+ 'stanext_breed_dict': self.breed_dict,
+ 'breeds_abbrev_dict': COMPLETE_ABBREV_DICT,
+ 'breeds_summary': COMPLETE_SUMMARY_BREEDS,
+ 'breeds_sim_martix_raw': SIM_MATRIX_RAW,
+ 'breeds_sim_abbrev_inds': SIM_ABBREV_INDICES
+ }
+ return info_dict
+
+
+ def get_breed_dict(self, breed_json_path, create_new_breed_json=False):
+ if create_new_breed_json:
+ breed_dict = {}
+ breed_index = 0
+ for img_name in self.train_name_list:
+ folder_name = img_name.split('/')[0]
+ breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1]
+ if not (folder_name in breed_dict):
+ breed_dict[folder_name] = {
+ 'breed_name': breed_name,
+ 'index': breed_index}
+ breed_index += 1
+ with open(breed_json_path, 'w', encoding='utf-8') as f: json.dump(breed_dict, f, ensure_ascii=False, indent=4)
+ else:
+ with open(breed_json_path) as json_file: breed_dict = json.load(json_file)
+ return breed_dict
+
+
+ def __getitem__(self, index):
+
+ if self.is_train:
+ name = self.train_name_list[index]
+ data = self.train_dict[name]
+ else:
+ name = self.test_name_list[index]
+ data = self.test_dict[name]
+
+ sf = self.scale_factor
+ rf = self.rot_factor
+
+ img_path = os.path.join(self.img_folder, data['img_path'])
+ try:
+ anipose_res_path = os.path.join(self.path_anipose_out_root, data['img_path'].replace('.jpg', '.json'))
+ with open(anipose_res_path) as f: anipose_data = json.load(f)
+ anipose_thr = 0.2
+ anipose_joints_0to24 = np.asarray(anipose_data['anipose_joints_0to24']).reshape((-1, 3))
+ anipose_joints_0to24_scores = anipose_joints_0to24[:, 2]
+ # anipose_joints_0to24_scores[anipose_joints_0to24_scores>anipose_thr] = 1.0
+ anipose_joints_0to24_scores[anipose_joints_0to24_scores<anipose_thr] = 0.0
+ # ... (keypoint assembly and bounding-box setup lines are missing from this hunk) ...
+ # bbox_max = 256
+ # bbox_s = bbox_diag / 200. # diagonal of the boundingbox will be 200
+ bbox_s = bbox_max / 200. * 256. / 200. # maximum side of the bbox will be 200
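+ # note: the crop()/transform() helpers treat the scale s as a size in units of 200 px
+ # (bearpaw convention), so this choice yields a crop window of roughly 1.28 * bbox_max,
+ # i.e. some margin is left around the bounding box.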
+ c = torch.Tensor(bbox_c)
+ s = bbox_s
+
+ # For single-person pose estimation with a centered/scaled figure
+ nparts = pts.size(0)
+ img = load_image(img_path) # CxHxW
+
+ # segmentation map (we reshape it to 3xHxW, such that we can do the
+ # same transformations as with the image)
+ if self.calc_seg:
+ seg = torch.Tensor(utils_stanext.get_seg_from_entry(data)[None, :, :])
+ seg = torch.cat(3*[seg])
+
+ r = 0
+ do_flip = False
+ if self.do_augment:
+ s = s*torch.randn(1).mul_(sf).add_(1).clamp(1-sf, 1+sf)[0]
+ r = torch.randn(1).mul_(rf).clamp(-2*rf, 2*rf)[0] if random.random() <= 0.6 else 0
+ # Flip
+ if random.random() <= 0.5:
+ do_flip = True
+ img = fliplr(img)
+ if self.calc_seg:
+ seg = fliplr(seg)
+ pts = shufflelr(pts, img.size(2), self.DATA_INFO.hflip_indices)
+ c[0] = img.size(2) - c[0]
+ # Color
+ img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
+ img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
+ img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
+
+ # Prepare image and groundtruth map
+ inp = crop(img, c, s, [self.inp_res, self.inp_res], rot=r)
+ img_border_mask = torch.all(inp > 1.0/256, dim = 0).unsqueeze(0).float() # 1 is foreground
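+ # pixels that are still (almost) exactly zero here were most likely introduced by the
+ # zero padding of the crop; img_border_mask marks the valid image area (used later for
+ # the silhouette loss).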
+ inp = color_normalize(inp, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)
+ if self.calc_seg:
+ seg = crop(seg, c, s, [self.inp_res, self.inp_res], rot=r)
+
+ # Generate ground truth
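+ # (one Gaussian heatmap per joint; draw_labelmap returns vis=0 when the transformed
+ # joint falls outside the heatmap, which zeroes the corresponding target_weight entry)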
+ tpts = pts.clone()
+ target_weight = tpts[:, 2].clone().view(nparts, 1)
+
+ target = torch.zeros(nparts, self.out_res, self.out_res)
+ for i in range(nparts):
+ # if tpts[i, 2] > 0: # This is evil!!
+ if tpts[i, 1] > 0:
+ tpts[i, 0:2] = to_torch(transform(tpts[i, 0:2]+1, c, s, [self.out_res, self.out_res], rot=r, as_int=False))
+ target[i], vis = draw_labelmap(target[i], tpts[i]-1, self.sigma, type=self.label_type)
+ target_weight[i, 0] *= vis
+ # NEW:
+ '''target_new, vis_new = draw_multiple_labelmaps((self.out_res, self.out_res), tpts[:, :2]-1, self.sigma, type=self.label_type)
+ target_weight_new = tpts[:, 2].clone().view(nparts, 1) * vis_new
+ target_new[(target_weight_new==0).reshape((-1)), :, :] = 0'''
+
+ # --- Meta info
+ this_breed = self.breed_dict[name.split('/')[0]] # 120
+ # add information about location within breed similarity matrix
+ folder_name = name.split('/')[0]
+ breed_name = folder_name.split(folder_name.split('-')[0] + '-')[1]
+ abbrev = COMPLETE_ABBREV_DICT[breed_name]
+ try:
+ sim_breed_index = COMPLETE_SUMMARY_BREEDS[abbrev]._ind_in_xlsx_matrix
+ except: # some breeds are not in the xlsx file
+ sim_breed_index = -1
+ meta = {'index' : index, 'center' : c, 'scale' : s,
+ 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight,
+ 'breed_index': this_breed['index'], 'sim_breed_index': sim_breed_index,
+ 'ind_dataset': 0} # ind_dataset=0 for stanext or stanexteasy or stanext 2
+ meta2 = {'index' : index, 'center' : c, 'scale' : s,
+ 'pts' : pts, 'tpts' : tpts, 'target_weight': target_weight,
+ 'ind_dataset': 3}
+
+ # return different things depending on dataset_mode
+ if self.dataset_mode=='keyp_only':
+ # save_input_image_with_keypoints(inp, meta['tpts'], out_path='./test_input_stanext.png', ratio_in_out=self.inp_res/self.out_res)
+ return inp, target, meta
+ elif self.dataset_mode=='keyp_and_seg':
+ meta['silh'] = seg[0, :, :]
+ meta['name'] = name
+ return inp, target, meta
+ elif self.dataset_mode=='keyp_and_seg_and_partseg':
+ # partseg is fake! it only exists so that this dataset can be combined with another dataset that has part segmentations
+ meta2['silh'] = seg[0, :, :]
+ meta2['name'] = name
+ fake_body_part_matrix = torch.ones((3, 256, 256)).long() * (-1)
+ meta2['body_part_matrix'] = fake_body_part_matrix
+ return inp, target, meta2
+ elif self.dataset_mode=='complete':
+ target_dict = meta
+ target_dict['silh'] = seg[0, :, :]
+ # NEW for silhouette loss
+ target_dict['img_border_mask'] = img_border_mask
+ target_dict['has_seg'] = True
+ if target_dict['silh'].sum() < 1:
+ if ((not self.is_train) and self.val_opt == 'test'):
+ raise ValueError
+ elif self.is_train:
+ print('had to replace training image')
+ replacement_index = max(0, index - 1)
+ inp, target_dict = self.__getitem__(replacement_index)
+ else:
+ # There seem to be a few validation images without segmentation
+ # which would lead to nan in iou calculation
+ replacement_index = max(0, index - 1)
+ inp, target_dict = self.__getitem__(replacement_index)
+ return inp, target_dict
+ else:
+ raise ValueError('unknown dataset_mode: ' + str(self.dataset_mode))
+
+
+ def __len__(self):
+ if self.is_train:
+ return len(self.train_name_list)
+ else:
+ return len(self.test_name_list)
+
+
diff --git a/src/stacked_hourglass/datasets/utils_stanext.py b/src/stacked_hourglass/datasets/utils_stanext.py
new file mode 100644
index 0000000000000000000000000000000000000000..83da8452f74ff8fb0ca95e2d8a42ba96972f684b
--- /dev/null
+++ b/src/stacked_hourglass/datasets/utils_stanext.py
@@ -0,0 +1,114 @@
+
+import os
+from matplotlib import pyplot as plt
+import glob
+import json
+import numpy as np
+from scipy.io import loadmat
+from csv import DictReader
+from collections import OrderedDict
+from pycocotools.mask import decode as decode_RLE
+
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+from configs.dataset_path_configs import IMG_V12_DIR, JSON_V12_DIR, STAN_V12_TRAIN_LIST_DIR, STAN_V12_VAL_LIST_DIR, STAN_V12_TEST_LIST_DIR
+
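+# note: the non-V12 code paths below still reference legacy constants (IMG_DIR, JSON_DIR,
+# STAN_ORIG_TRAIN_LIST_DIR, STAN_ORIG_TEST_LIST_DIR) that are not imported here, so only
+# the V12=True paths are usable as-is.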
+
+def get_img_dir(V12):
+ if V12:
+ return IMG_V12_DIR
+ else:
+ return IMG_DIR
+
+def get_seg_from_entry(entry):
+ """Given a .json entry, returns the binary mask as a numpy array"""
+ rle = {
+ "size": [entry['img_height'], entry['img_width']],
+ "counts": entry['seg']}
+ decoded = decode_RLE(rle)
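+ # decode_RLE (pycocotools) returns a HxW uint8 array with 1 for foreground pixels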
+ return decoded
+
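+# full_animal_visible: True if the segmentation mask does not touch any image border,
+# i.e. the whole animal is presumably inside the frame.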
+def full_animal_visible(seg_data):
+ if seg_data[0, :].sum() == 0 and seg_data[seg_data.shape[0]-1, :].sum() == 0 and seg_data[:, 0].sum() == 0 and seg_data[:, seg_data.shape[1]-1].sum() == 0:
+ return True
+ else:
+ return False
+
+def load_train_and_test_lists(train_list_dir=None , test_list_dir=None):
+ """ returns sets containing names such as 'n02085620-Chihuahua/n02085620_5927.jpg' """
+ # train data
+ train_list_mat = loadmat(train_list_dir)
+ train_list = []
+ for ind in range(0, train_list_mat['file_list'].shape[0]):
+ name = train_list_mat['file_list'][ind, 0][0]
+ train_list.append(name)
+ # test data
+ test_list_mat = loadmat(test_list_dir)
+ test_list = []
+ for ind in range(0, test_list_mat['file_list'].shape[0]):
+ name = test_list_mat['file_list'][ind, 0][0]
+ test_list.append(name)
+ return train_list, test_list
+
+
+
+def _filter_dict(t_list, j_dict, n_kp_min=4):
+ """ should only be used by load_stanext_json_as_dict() """
+ out_dict = {}
+ for sample in t_list:
+ if sample in j_dict.keys():
+ n_kp = np.asarray(j_dict[sample]['joints'])[:, 2].sum()
+ if n_kp >= n_kp_min:
+ out_dict[sample] = j_dict[sample]
+ return out_dict
+
+def load_stanext_json_as_dict(split_train_test=True, V12=True):
+ # load json into memory
+ if V12:
+ with open(JSON_V12_DIR) as infile:
+ json_data = json.load(infile)
+ # with open(JSON_V12_DIR) as infile: json_data = json.load(infile, object_pairs_hook=OrderedDict)
+ else:
+ with open(JSON_DIR) as infile:
+ json_data = json.load(infile)
+ # convert json data to a dictionary of img_path : all_data, for easy lookup
+ json_dict = {i['img_path']: i for i in json_data}
+ if split_train_test:
+ if V12:
+ train_list_numbers = np.load(STAN_V12_TRAIN_LIST_DIR)
+ val_list_numbers = np.load(STAN_V12_VAL_LIST_DIR)
+ test_list_numbers = np.load(STAN_V12_TEST_LIST_DIR)
+ train_list = [json_data[i]['img_path'] for i in train_list_numbers]
+ val_list = [json_data[i]['img_path'] for i in val_list_numbers]
+ test_list = [json_data[i]['img_path'] for i in test_list_numbers]
+ train_dict = _filter_dict(train_list, json_dict, n_kp_min=4)
+ val_dict = _filter_dict(val_list, json_dict, n_kp_min=4)
+ test_dict = _filter_dict(test_list, json_dict, n_kp_min=4)
+ return train_dict, test_dict, val_dict
+ else:
+ train_list, test_list = load_train_and_test_lists(train_list_dir=STAN_ORIG_TRAIN_LIST_DIR , test_list_dir=STAN_ORIG_TEST_LIST_DIR)
+ train_dict = _filter_dict(train_list, json_dict)
+ test_dict = _filter_dict(test_list, json_dict)
+ return train_dict, test_dict, None
+ else:
+ return json_dict
+
+def get_dog(json_dict, name, img_dir=None): # (json_dict, name, img_dir=IMG_DIR)
+ """ takes the name of a dog, and loads in all the relevant information as a dictionary:
+ dict_keys(['img_path', 'img_width', 'img_height', 'joints', 'img_bbox',
+ 'is_multiple_dogs', 'seg', 'img_data', 'seg_data'])
+ img_bbox: [x0, y0, width, height] """
+ data = json_dict[name]
+ # load img
+ img_data = plt.imread(os.path.join(img_dir, data['img_path']))
+ # load seg
+ seg_data = get_seg_from_entry(data)
+ # add to output
+ data['img_data'] = img_data # 0 to 255
+ data['seg_data'] = seg_data # 0: bg, 1: fg
+ return data
+
+
+
+
+
diff --git a/src/stacked_hourglass/model.py b/src/stacked_hourglass/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df09246044e8450efeb0b12f86cb9780a435a60
--- /dev/null
+++ b/src/stacked_hourglass/model.py
@@ -0,0 +1,308 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+# Hourglass network inserted in the pre-activated Resnet
+# Use lr=0.01 for current version
+# (c) YANG, Wei
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.hub import load_state_dict_from_url
+
+
+__all__ = ['HourglassNet', 'hg']
+
+
+model_urls = {
+ 'hg1': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg1-ce125879.pth',
+ 'hg2': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg2-15e342d9.pth',
+ 'hg8': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg8-90e5d470.pth',
+}
+
+
+class Bottleneck(nn.Module):
+ expansion = 2
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=True)
+ self.bn3 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.bn1(x)
+ out = self.relu(out)
+ out = self.conv1(out)
+
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ out = self.bn3(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+
+class Hourglass(nn.Module):
+ def __init__(self, block, num_blocks, planes, depth):
+ super(Hourglass, self).__init__()
+ self.depth = depth
+ self.block = block
+ self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+ def _make_residual(self, block, num_blocks, planes):
+ layers = []
+ for i in range(0, num_blocks):
+ layers.append(block(planes*block.expansion, planes))
+ return nn.Sequential(*layers)
+
+ def _make_hour_glass(self, block, num_blocks, planes, depth):
+ hg = []
+ for i in range(depth):
+ res = []
+ for j in range(3):
+ res.append(self._make_residual(block, num_blocks, planes))
+ if i == 0:
+ res.append(self._make_residual(block, num_blocks, planes))
+ hg.append(nn.ModuleList(res))
+ return nn.ModuleList(hg)
+
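+ # Branch layout per depth level: [0] skip path at the current resolution, [1] applied
+ # after max-pooling, [2] applied to the output of the inner recursion before upsampling,
+ # [3] the innermost bottleneck (only present at level 0).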
+ def _hour_glass_forward(self, n, x):
+ up1 = self.hg[n-1][0](x)
+ low1 = F.max_pool2d(x, 2, stride=2)
+ low1 = self.hg[n-1][1](low1)
+
+ if n > 1:
+ low2 = self._hour_glass_forward(n-1, low1)
+ else:
+ low2 = self.hg[n-1][3](low1)
+ low3 = self.hg[n-1][2](low2)
+ up2 = F.interpolate(low3, scale_factor=2)
+ out = up1 + up2
+ return out
+
+ def forward(self, x):
+ return self._hour_glass_forward(self.depth, x)
+
+
+class HourglassNet(nn.Module):
+ '''Hourglass model from Newell et al ECCV 2016'''
+ def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
+ super(HourglassNet, self).__init__()
+
+ self.inplanes = 64
+ self.num_feats = 128
+ self.num_stacks = num_stacks
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=True)
+ self.bn1 = nn.BatchNorm2d(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_residual(block, self.inplanes, 1)
+ self.layer2 = self._make_residual(block, self.inplanes, 1)
+ self.layer3 = self._make_residual(block, self.num_feats, 1)
+ self.maxpool = nn.MaxPool2d(2, stride=2)
+ self.upsample_seg = upsample_seg
+ self.add_partseg = add_partseg
+
+ # build hourglass modules
+ ch = self.num_feats*block.expansion
+ hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+ for i in range(num_stacks):
+ hg.append(Hourglass(block, num_blocks, self.num_feats, 4))
+ res.append(self._make_residual(block, self.num_feats, num_blocks))
+ fc.append(self._make_fc(ch, ch))
+ score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
+ if i < num_stacks-1:
+ fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
+ score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
+ self.hg = nn.ModuleList(hg)
+ self.res = nn.ModuleList(res)
+ self.fc = nn.ModuleList(fc)
+ self.score = nn.ModuleList(score)
+ self.fc_ = nn.ModuleList(fc_)
+ self.score_ = nn.ModuleList(score_)
+
+ if self.add_partseg:
+ self.hg_ps = (Hourglass(block, num_blocks, self.num_feats, 4))
+ self.res_ps = (self._make_residual(block, self.num_feats, num_blocks))
+ self.fc_ps = (self._make_fc(ch, ch))
+ self.score_ps = (nn.Conv2d(ch, num_partseg, kernel_size=1, bias=True))
+ self.ups_upsampling_ps = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
+
+
+ if self.upsample_seg:
+ self.ups_upsampling = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
+ self.ups_conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3,
+ bias=True)
+ self.ups_bn1 = nn.BatchNorm2d(32)
+ self.ups_conv1 = nn.Conv2d(32, 16, kernel_size=7, stride=1, padding=3,
+ bias=True)
+ self.ups_bn2 = nn.BatchNorm2d(16+2)
+ self.ups_conv2 = nn.Conv2d(16+2, 16, kernel_size=5, stride=1, padding=2,
+ bias=True)
+ self.ups_bn3 = nn.BatchNorm2d(16)
+ self.ups_conv3 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2,
+ bias=True)
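+ # Together these layers form a small refinement head: the coarse segmentation logits are
+ # upsampled to the input resolution, concatenated with shallow image features, and a
+ # residual correction is predicted (see the upsample_seg branch of forward()).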
+
+
+
+ def _make_residual(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=True),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_fc(self, inplanes, outplanes):
+ bn = nn.BatchNorm2d(inplanes)
+ conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
+ return nn.Sequential(
+ conv,
+ bn,
+ self.relu,
+ )
+
+ def forward(self, x_in):
+ out = []
+ out_seg = []
+ out_partseg = []
+ x = self.conv1(x_in)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.maxpool(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ for i in range(self.num_stacks):
+ if i == self.num_stacks - 1:
+ if self.add_partseg:
+ y_ps = self.hg_ps(x)
+ y_ps = self.res_ps(y_ps)
+ y_ps = self.fc_ps(y_ps)
+ score_ps = self.score_ps(y_ps)
+ out_partseg.append(score_ps[:, :, :, :])
+ y = self.hg[i](x)
+ y = self.res[i](y)
+ y = self.fc[i](y)
+ score = self.score[i](y)
+ if self.upsample_seg:
+ out.append(score[:, :-2, :, :])
+ out_seg.append(score[:, -2:, :, :])
+ else:
+ out.append(score)
+ if i < self.num_stacks-1:
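+ # standard stacked-hourglass intermediate supervision: features and predictions are
+ # projected back and added to the input of the next stack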
+ fc_ = self.fc_[i](y)
+ score_ = self.score_[i](score)
+ x = x + fc_ + score_
+
+ if self.upsample_seg:
+ # PLAN: add a residual to the upsampled version of the segmentation image
+ # upsample predicted segmentation
+ seg_score = score[:, -2:, :, :]
+ seg_score_256 = self.ups_upsampling(seg_score)
+ # prepare input image
+
+ ups_img = self.ups_conv0(x_in)
+
+ ups_img = self.ups_bn1(ups_img)
+ ups_img = self.relu(ups_img)
+ ups_img = self.ups_conv1(ups_img)
+
+ # import pdb; pdb.set_trace()
+
+ ups_conc = torch.cat((seg_score_256, ups_img), 1)
+
+ # ups_conc = self.ups_bn2(ups_conc)
+ ups_conc = self.relu(ups_conc)
+ ups_conc = self.ups_conv2(ups_conc)
+
+ ups_conc = self.ups_bn3(ups_conc)
+ ups_conc = self.relu(ups_conc)
+ correction = self.ups_conv3(ups_conc)
+
+ seg_final = seg_score_256 + correction
+
+ if self.add_partseg:
+ partseg_final = self.ups_upsampling_ps(score_ps)
+ out_dict = {'out_list_kp': out,
+ 'out_list_seg': out,
+ 'seg_final': seg_final,
+ 'out_list_partseg': out_partseg,
+ 'partseg_final': partseg_final
+ }
+ return out_dict
+ else:
+ out_dict = {'out_list_kp': out,
+ 'out_list_seg': out,
+ 'seg_final': seg_final
+ }
+ return out_dict
+
+ return out
+
+
+def hg(**kwargs):
+ model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'],
+ num_classes=kwargs['num_classes'], upsample_seg=kwargs['upsample_seg'],
+ add_partseg=kwargs['add_partseg'], num_partseg=kwargs['num_partseg'])
+ return model
+
+
+def _hg(arch, pretrained, progress, **kwargs):
+ model = hg(**kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def hg1(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
+ return _hg('hg1', pretrained, progress, num_stacks=1, num_blocks=num_blocks,
+ num_classes=num_classes, upsample_seg=upsample_seg,
+ add_partseg=add_partseg, num_partseg=num_partseg)
+
+
+def hg2(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
+ return _hg('hg2', pretrained, progress, num_stacks=2, num_blocks=num_blocks,
+ num_classes=num_classes, upsample_seg=upsample_seg,
+ add_partseg=add_partseg, num_partseg=num_partseg)
+
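+# note: model_urls above has no entry for 'hg4', so hg4(pretrained=True) would fail;
+# pretrained weights are only registered for hg1, hg2 and hg8.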
+def hg4(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
+ return _hg('hg4', pretrained, progress, num_stacks=4, num_blocks=num_blocks,
+ num_classes=num_classes, upsample_seg=upsample_seg,
+ add_partseg=add_partseg, num_partseg=num_partseg)
+
+def hg8(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
+ return _hg('hg8', pretrained, progress, num_stacks=8, num_blocks=num_blocks,
+ num_classes=num_classes, upsample_seg=upsample_seg,
+ add_partseg=add_partseg, num_partseg=num_partseg)
diff --git a/src/stacked_hourglass/predictor.py b/src/stacked_hourglass/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..30be3b4fe816cc33018b61632c4ba120ea66dfc3
--- /dev/null
+++ b/src/stacked_hourglass/predictor.py
@@ -0,0 +1,119 @@
+
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import torch
+from stacked_hourglass.utils.evaluation import final_preds_untransformed
+from stacked_hourglass.utils.imfit import fit, calculate_fit_contain_output_area
+from stacked_hourglass.utils.transforms import color_normalize, fliplr, flip_back
+
+
+def _check_batched(images):
+ if isinstance(images, (tuple, list)):
+ return True
+ if images.ndimension() == 4:
+ return True
+ return False
+
+
+class HumanPosePredictor:
+ def __init__(self, model, device=None, data_info=None, input_shape=None):
+ """Helper class for predicting 2D human pose joint locations.
+
+ Args:
+ model: The model for generating joint heatmaps.
+ device: The computational device to use for inference.
+ data_info: Specifications of the data (required; this fork does not fall back to ``Mpii.DATA_INFO``).
+ input_shape: The input dimensions of the model (height, width).
+ """
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ device = torch.device(device)
+ model.to(device)
+ self.model = model
+ self.device = device
+
+ if data_info is None:
+ raise ValueError
+ # self.data_info = Mpii.DATA_INFO
+ else:
+ self.data_info = data_info
+
+ # Input shape ordering: H, W
+ if input_shape is None:
+ self.input_shape = (256, 256)
+ elif isinstance(input_shape, int):
+ self.input_shape = (input_shape, input_shape)
+ else:
+ self.input_shape = input_shape
+
+ def do_forward(self, input_tensor):
+ self.model.eval()
+ with torch.no_grad():
+ output = self.model(input_tensor)
+ return output
+
+ def prepare_image(self, image):
+ was_fixed_point = not image.is_floating_point()
+ image = torch.empty_like(image, dtype=torch.float32).copy_(image)
+ if was_fixed_point:
+ image /= 255.0
+ if image.shape[-2:] != self.input_shape:
+ image = fit(image, self.input_shape, fit_mode='contain')
+ image = color_normalize(image, self.data_info.rgb_mean, self.data_info.rgb_stddev)
+ return image
+
+ def estimate_heatmaps(self, images, flip=False):
+ is_batched = _check_batched(images)
+ raw_images = images if is_batched else images.unsqueeze(0)
+ input_tensor = torch.empty((len(raw_images), 3, *self.input_shape),
+ device=self.device, dtype=torch.float32)
+ for i, raw_image in enumerate(raw_images):
+ input_tensor[i] = self.prepare_image(raw_image)
+ heatmaps = self.do_forward(input_tensor)[-1].cpu()
+ if flip:
+ flip_input = fliplr(input_tensor)
+ flip_heatmaps = self.do_forward(flip_input)[-1].cpu()
+ heatmaps += flip_back(flip_heatmaps, self.data_info.hflip_indices)
+ heatmaps /= 2
+ if is_batched:
+ return heatmaps
+ else:
+ return heatmaps[0]
+
+ def estimate_joints(self, images, flip=False):
+ """Estimate human joint locations from input images.
+
+ Images are expected to be centred on a human subject and scaled reasonably.
+
+ Args:
+ images: The images to estimate joint locations for. Can be a single image or a list
+ of images.
+ flip (bool): If set to true, evaluates on flipped versions of the images as well and
+ averages the results.
+
+ Returns:
+ The predicted human joint locations in image pixel space.
+ """
+ is_batched = _check_batched(images)
+ raw_images = images if is_batched else images.unsqueeze(0)
+ heatmaps = self.estimate_heatmaps(raw_images, flip=flip).cpu()
+ # final_preds_untransformed compares the first component of shape with x and second with y
+ # This relates to the image Width, Height (Heatmap has shape Height, Width)
+ coords = final_preds_untransformed(heatmaps, heatmaps.shape[-2:][::-1])
+ # Rescale coords to pixel space of specified images.
+ for i, image in enumerate(raw_images):
+ # When returning to the original image space we need to compensate for the fact that we
+ # used fit_mode='contain' when preparing the images for inference.
+ y_off, x_off, height, width = calculate_fit_contain_output_area(*image.shape[-2:], *self.input_shape)
+ coords[i, :, 1] *= self.input_shape[-2] / heatmaps.shape[-2]
+ coords[i, :, 1] -= y_off
+ coords[i, :, 1] *= image.shape[-2] / height
+ coords[i, :, 0] *= self.input_shape[-1] / heatmaps.shape[-1]
+ coords[i, :, 0] -= x_off
+ coords[i, :, 0] *= image.shape[-1] / width
+ if is_batched:
+ return coords
+ else:
+ return coords[0]
diff --git a/src/stacked_hourglass/utils/__init__.py b/src/stacked_hourglass/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/stacked_hourglass/utils/evaluation.py b/src/stacked_hourglass/utils/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02a4d804332ae263c4fb005d36cccd967ade029
--- /dev/null
+++ b/src/stacked_hourglass/utils/evaluation.py
@@ -0,0 +1,188 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import math
+import torch
+from kornia.geometry.subpix import dsnt # kornia 0.4.0
+import torch.nn.functional as F
+from .transforms import transform_preds
+
+__all__ = ['get_preds', 'get_preds_soft', 'calc_dists', 'dist_acc', 'accuracy', 'final_preds_untransformed',
+ 'final_preds', 'AverageMeter']
+
+def get_preds(scores, return_maxval=False):
+ ''' get predictions from score maps in torch Tensor
+ return type: torch.FloatTensor of 1-based (x, y) coordinates
+ '''
+ assert scores.dim() == 4, 'Score maps should be 4-dim'
+ maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)
+
+ maxval = maxval.view(scores.size(0), scores.size(1), 1)
+ idx = idx.view(scores.size(0), scores.size(1), 1) + 1
+
+ preds = idx.repeat(1, 1, 2).float()
+
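+ # convert the flat argmax index back to 1-based (x, y) grid coordinates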
+ preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
+ preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(3)) + 1
+
+ pred_mask = maxval.gt(0).repeat(1, 1, 2).float() # values > 0
+ preds *= pred_mask
+ if return_maxval:
+ return preds, maxval
+ else:
+ return preds
+
+
+def get_preds_soft(scores, return_maxval=False, norm_coords=False, norm_and_unnorm_coords=False):
+ ''' get predictions from score maps in torch Tensor
+ predictions are made assuming a logit output map
+ return type: torch.FloatTensor (soft-argmax coordinates, optionally with confidence values)
+ '''
+
+ # New: work on logit predictions
+ scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1))
+ # maxval_norm, idx_norm = torch.max(scores_norm.view(scores.size(0), scores.size(1), -1), 2)
+ # from unnormalized to normalized see:
+ # from -1to1 to 0to64
+ # see https://github.com/kornia/kornia/blob/b9ffe7efcba7399daeeb8028f10c22941b55d32d/kornia/utils/grid.py#L7 (line 40)
+ # xs = (xs / (width - 1) - 0.5) * 2
+ # ys = (ys / (height - 1) - 0.5) * 2
+
+ device = scores.device
+
+ if return_maxval:
+ preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)
+ # grid_sample(input, grid, mode='bilinear', padding_mode='zeros')
+ gs_input_single = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3])) # (120, 1, 64, 64)
+
+ half_pad = 2
+ gs_input_single_padded = F.pad(input=gs_input_single, pad=(half_pad, half_pad, half_pad, half_pad, 0, 0, 0, 0), mode='constant', value=0)
+ gs_input_all = torch.zeros((gs_input_single.shape[0], 9, gs_input_single.shape[2], gs_input_single.shape[3])).to(device)
+ ind_tot = 0
+ for ind0 in [-1, 0, 1]:
+ for ind1 in [-1, 0, 1]:
+ gs_input_all[:, ind_tot, :, :] = gs_input_single_padded[:, 0, half_pad+ind0:-half_pad+ind0, half_pad+ind1:-half_pad+ind1]
+ ind_tot +=1
+
+ gs_grid = preds_normalized.reshape((-1, 2))[:, None, None, :] # (120, 1, 1, 2)
+ gs_output_all = F.grid_sample(gs_input_all, gs_grid, mode='nearest', padding_mode='zeros', align_corners=True).reshape((gs_input_all.shape[0], gs_input_all.shape[1], 1))
+ gs_output = gs_output_all.sum(axis=1)
+ # scores_norm[0, :, :, :].max(axis=2)[0].max(axis=1)[0]
+ # gs_output[0, :, 0]
+ gs_output_resh = gs_output.reshape((scores_norm.shape[0], scores_norm.shape[1], 1))
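+ # gs_output_resh approximates a per-joint confidence: the normalized heatmap is sampled
+ # (nearest neighbour) at the predicted location for each of the nine shifted copies,
+ # i.e. the 3x3 neighbourhood around the prediction is summed.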
+
+ if norm_and_unnorm_coords:
+ preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
+ return preds, preds_normalized, gs_output_resh
+ elif norm_coords:
+ return preds_normalized, gs_output_resh
+ else:
+ preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
+ return preds, gs_output_resh
+ else:
+ if norm_coords:
+ preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)
+ return preds_normalized
+ else:
+ preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
+ return preds
+
+
+def calc_dists(preds, target, normalize):
+ preds = preds.float()
+ target = target.float()
+ dists = torch.zeros(preds.size(1), preds.size(0))
+ for n in range(preds.size(0)):
+ for c in range(preds.size(1)):
+ if target[n,c,0] > 1 and target[n, c, 1] > 1:
+ dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
+ else:
+ dists[c, n] = -1
+ return dists
+
+def dist_acc(dist, thr=0.5):
+ ''' Return percentage below threshold while ignoring values with a -1 '''
+ dist = dist[dist != -1]
+ if len(dist) > 0:
+ return 1.0 * (dist < thr).sum().item() / len(dist)
+ else:
+ return -1
+
+def accuracy(output, target, idxs=None, thr=0.5):
+ ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations
+ First value to be returned is average accuracy across 'idxs', followed by individual accuracies
+ '''
+ if idxs is None:
+ idxs = list(range(target.shape[-3]))
+ preds = get_preds_soft(output) # get_preds(output)
+ gts = get_preds(target)
+ norm = torch.ones(preds.size(0))*output.size(3)/10
+ dists = calc_dists(preds, gts, norm)
+
+ acc = torch.zeros(len(idxs)+1)
+ avg_acc = 0
+ cnt = 0
+
+ for i in range(len(idxs)):
+ acc[i+1] = dist_acc(dists[idxs[i]], thr=thr)
+ if acc[i+1] >= 0:
+ avg_acc = avg_acc + acc[i+1]
+ cnt += 1
+
+ if cnt != 0:
+ acc[0] = avg_acc / cnt
+ return acc
+
+def final_preds_untransformed(output, res):
+ coords = get_preds_soft(output) # get_preds(output) # float type
+
+ # pose-processing
+ for n in range(coords.size(0)):
+ for p in range(coords.size(1)):
+ hm = output[n][p]
+ px = int(math.floor(coords[n][p][0]))
+ py = int(math.floor(coords[n][p][1]))
+ if px > 1 and px < res[0] and py > 1 and py < res[1]:
+ diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
+ coords[n][p] += diff.sign() * .25
+ coords += 0.5
+
+ if coords.dim() < 3:
+ coords = coords.unsqueeze(0)
+
+ coords -= 1 # Convert from 1-based to 0-based coordinates
+
+ return coords
+
+def final_preds(output, center, scale, res):
+ coords = final_preds_untransformed(output, res)
+ preds = coords.clone()
+
+ # Transform back
+ for i in range(coords.size(0)):
+ preds[i] = transform_preds(coords[i], center[i], scale[i], res)
+
+ if preds.dim() < 3:
+ preds = preds.unsqueeze(0)
+
+ return preds
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/src/stacked_hourglass/utils/finetune.py b/src/stacked_hourglass/utils/finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7990b26a90e824f02141d7908907679f544f98c
--- /dev/null
+++ b/src/stacked_hourglass/utils/finetune.py
@@ -0,0 +1,39 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import torch
+from torch.nn import Conv2d, ModuleList
+
+
+def change_hg_outputs(model, indices):
+ """Change the output classes of the model.
+
+ Args:
+ model: The model to modify.
+ indices: An array of indices describing the new model outputs. For example, [3, 4, None]
+ will modify the model to have 3 outputs, the first two of which have parameters
+ copied from the fourth and fifth outputs of the original model.
+ """
+ with torch.no_grad():
+ new_n_outputs = len(indices)
+ new_score = ModuleList()
+ for conv in model.score:
+ new_conv = Conv2d(conv.in_channels, new_n_outputs, conv.kernel_size, conv.stride)
+ new_conv = new_conv.to(conv.weight.device, conv.weight.dtype)
+ for i, index in enumerate(indices):
+ if index is not None:
+ new_conv.weight[i] = conv.weight[index]
+ new_conv.bias[i] = conv.bias[index]
+ new_score.append(new_conv)
+ model.score = new_score
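+ # the score_ layers project predictions back to feature space, so for them the selected
+ # indices are copied along the input-channel dimension (weight[:, index]) instead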
+ new_score_ = ModuleList()
+ for conv in model.score_:
+ new_conv = Conv2d(new_n_outputs, conv.out_channels, conv.kernel_size, conv.stride)
+ new_conv = new_conv.to(conv.weight.device, conv.weight.dtype)
+ for i, index in enumerate(indices):
+ if index is not None:
+ new_conv.weight[:, i] = conv.weight[:, index]
+ new_conv.bias = conv.bias
+ new_score_.append(new_conv)
+ model.score_ = new_score_
diff --git a/src/stacked_hourglass/utils/imfit.py b/src/stacked_hourglass/utils/imfit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0d2e131bf3c1bd2e0c740d9c8cfd9d847f523d
--- /dev/null
+++ b/src/stacked_hourglass/utils/imfit.py
@@ -0,0 +1,144 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import torch
+from torch.nn.functional import interpolate
+
+
+def _resize(tensor, size, mode='bilinear'):
+ """Resize the image.
+
+ Args:
+ tensor (torch.Tensor): The image tensor to be resized.
+ size (tuple of int): Size of the resized image (height, width).
+ mode (str): The pixel sampling interpolation mode to be used.
+
+ Returns:
+ Tensor: The resized image tensor.
+ """
+ assert len(size) == 2
+
+ # If the tensor is already the desired size, return it immediately.
+ if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]:
+ return tensor
+
+ if not tensor.is_floating_point():
+ dtype = tensor.dtype
+ tensor = tensor.to(torch.float32)
+ tensor = _resize(tensor, size, mode)
+ return tensor.to(dtype)
+
+ out_shape = (*tensor.shape[:-2], *size)
+ if tensor.ndimension() < 3:
+ raise Exception('tensor must have at least 3 dimensions')
+ elif tensor.ndimension() == 3:
+ tensor = tensor.unsqueeze(0)
+ elif tensor.ndimension() > 4:
+ tensor = tensor.view(-1, *tensor.shape[-3:])
+ align_corners = None
+ if mode in {'linear', 'bilinear', 'trilinear'}:
+ align_corners = False
+ resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners)
+ return resized.view(*out_shape)
+
+
+def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0):
+ """Crop the image, padding out-of-bounds regions.
+
+ Args:
+ tensor (torch.Tensor): The image tensor to be cropped.
+ t (int): Top pixel coordinate.
+ l (int): Left pixel coordinate.
+ h (int): Height of the cropped image.
+ w (int): Width of the cropped image.
+ padding_mode (str): Padding mode (currently "constant" is the only valid option).
+ fill (float): Fill value to use with constant padding.
+
+ Returns:
+ Tensor: The cropped image tensor.
+ """
+ # If the _crop region is wholly within the image, simply narrow the tensor.
+ if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1):
+ return tensor[..., t:t+h, l:l+w]
+
+ if padding_mode == 'constant':
+ result = torch.full((*tensor.size()[:-2], h, w), fill,
+ device=tensor.device, dtype=tensor.dtype)
+ else:
+ raise Exception('_crop only supports "constant" padding currently.')
+
+ sx1 = l
+ sy1 = t
+ sx2 = l + w
+ sy2 = t + h
+ dx1 = 0
+ dy1 = 0
+
+ if sx1 < 0:
+ dx1 = -sx1
+ w += sx1
+ sx1 = 0
+
+ if sy1 < 0:
+ dy1 = -sy1
+ h += sy1
+ sy1 = 0
+
+ if sx2 >= tensor.size(-1):
+ w -= sx2 - tensor.size(-1)
+
+ if sy2 >= tensor.size(-2):
+ h -= sy2 - tensor.size(-2)
+
+ # Copy the in-bounds sub-area of the _crop region into the result tensor.
+ if h > 0 and w > 0:
+ src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w)
+ dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w)
+ dst.copy_(src)
+
+ return result
+
+
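+# example: calculate_fit_contain_output_area(480, 640, 256, 256) returns (32, 0, 192, 256),
+# i.e. a 480x640 image scaled by k=0.4 fits as a 192x256 area with a vertical offset of 32.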
+def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width):
+ ih, iw = in_height, in_width
+ k = min(out_width / iw, out_height / ih)
+ oh = round(k * ih)
+ ow = round(k * iw)
+ y_off = (out_height - oh) // 2
+ x_off = (out_width - ow) // 2
+ return y_off, x_off, oh, ow
+
+
+def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0):
+ """Fit the image within the given spatial dimensions.
+
+ Args:
+ tensor (torch.Tensor): The image tensor to be fit.
+ size (tuple of int): Size of the output (height, width).
+ fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's
+ `object-fit` property.
+ fill (float): padding value (only applicable in 'contain' mode).
+
+ Returns:
+ Tensor: The resized image tensor.
+ """
+ if fit_mode == 'fill':
+ return _resize(tensor, size, mode=resize_mode)
+ elif fit_mode == 'contain':
+ y_off, x_off, oh, ow = calculate_fit_contain_output_area(*tensor.shape[-2:], *size)
+ resized = _resize(tensor, (oh, ow), mode=resize_mode)
+ result = tensor.new_full((*tensor.size()[:-2], *size), fill)
+ result[..., y_off:y_off + oh, x_off:x_off + ow] = resized
+ return result
+ elif fit_mode == 'cover':
+ ih, iw = tensor.shape[-2:]
+ k = max(size[-1] / iw, size[-2] / ih)
+ oh = round(k * ih)
+ ow = round(k * iw)
+ resized = _resize(tensor, (oh, ow), mode=resize_mode)
+ y_trim = (oh - size[-2]) // 2
+ x_trim = (ow - size[-1]) // 2
+ result = _crop(resized, y_trim, x_trim, size[-2], size[-1])
+ return result
+ raise ValueError('Invalid fit_mode: ' + repr(fit_mode))
diff --git a/src/stacked_hourglass/utils/imutils.py b/src/stacked_hourglass/utils/imutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5540728cc9f85e55b560308417c3b77d9c678a13
--- /dev/null
+++ b/src/stacked_hourglass/utils/imutils.py
@@ -0,0 +1,125 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import numpy as np
+
+from .misc import to_numpy, to_torch
+from .pilutil import imread, imresize
+from kornia.geometry.subpix import dsnt
+import torch
+
+def im_to_numpy(img):
+ img = to_numpy(img)
+ img = np.transpose(img, (1, 2, 0)) # H*W*C
+ return img
+
+def im_to_torch(img):
+ img = np.transpose(img, (2, 0, 1)) # C*H*W
+ img = to_torch(img).float()
+ if img.max() > 1:
+ img /= 255
+ return img
+
+def load_image(img_path):
+ # H x W x C => C x H x W
+ return im_to_torch(imread(img_path, mode='RGB'))
+
+# =============================================================================
+# Helpful functions generating groundtruth labelmap
+# =============================================================================
+
+def gaussian(shape=(7,7),sigma=1):
+ """
+ 2D gaussian mask - should give the same result as MATLAB's
+ fspecial('gaussian',[shape],[sigma])
+ """
+ m,n = [(ss-1.)/2. for ss in shape]
+ y,x = np.ogrid[-m:m+1,-n:n+1]
+ h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
+ h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
+ return to_torch(h).float()
+
+def draw_labelmap_orig(img, pt, sigma, type='Gaussian'):
+ # Draw a 2D gaussian
+ # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
+ # maximum value of the gaussian is 1
+ img = to_numpy(img)
+
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
+ br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
+ if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
+ br[0] < 0 or br[1] < 0):
+ # If not, just return the image as is
+ return to_torch(img), 0
+
+ # Generate gaussian
+ size = 6 * sigma + 1
+ x = np.arange(0, size, 1, float)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized, we want the center value to equal 1
+ if type == 'Gaussian':
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ elif type == 'Cauchy':
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+
+ return to_torch(img), 1
+
+
+
+def draw_labelmap(img, pt, sigma, type='Gaussian'):
+ # Draw a 2D gaussian
+ # real probability distribution: the sum of all values is 1
+ img = to_numpy(img)
+ if not type == 'Gaussian':
+ raise NotImplementedError
+
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
+ br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
+ if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
+ br[0] < 0 or br[1] < 0):
+ # If not, just return the image as is
+ return to_torch(img), 0
+
+ # Generate gaussian
+ # img_new = dsnt.render_gaussian2d(mean=torch.tensor([[-1, 0]]).float(), std=torch.tensor([[sigma, sigma]]).float(), size=(img.shape[0], img.shape[1]), normalized_coordinates=False)
+ img_new = dsnt.render_gaussian2d(mean=torch.tensor([[pt[0], pt[1]]]).float(), \
+ std=torch.tensor([[sigma, sigma]]).float(), \
+ size=(img.shape[0], img.shape[1]), \
+ normalized_coordinates=False)
+ img_new = img_new[0, :, :] # this is a torch image
+ return img_new, 1
+
+
+def draw_multiple_labelmaps(out_res, pts, sigma, type='Gaussian'):
+ # Draw a 2D gaussian
+ # real probability distribution: the sum of all values is 1
+ if not type == 'Gaussian':
+ raise NotImplementedError
+
+ # Generate gaussians
+ n_pts = pts.shape[0]
+ imgs_new = dsnt.render_gaussian2d(mean=pts[:, :2], \
+ std=torch.tensor([[sigma, sigma]]).float().repeat((n_pts, 1)), \
+ size=(out_res[0], out_res[1]), \
+ normalized_coordinates=False) # shape: (n_pts, out_res[0], out_res[1])
+
+ visibility_orig = imgs_new.sum(axis=2).sum(axis=1) # shape: (n_pts)
+ visibility = torch.zeros((n_pts, 1), dtype=torch.float32)
+ visibility[visibility_orig>=0.99999] = 1.0
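+ # a joint counts as visible only if (almost) all of its Gaussian mass lies inside the
+ # heatmap, i.e. the rendered map sums to ~1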
+
+ # import pdb; pdb.set_trace()
+
+ return imgs_new, visibility.int()
\ No newline at end of file
diff --git a/src/stacked_hourglass/utils/logger.py b/src/stacked_hourglass/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e42823a88ae20117fc5aa191f569491c102b1f3
--- /dev/null
+++ b/src/stacked_hourglass/utils/logger.py
@@ -0,0 +1,73 @@
+
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import numpy as np
+
+__all__ = ['Logger']
+
+
+class Logger:
+ """Log training metrics to a file."""
+ def __init__(self, fpath, resume=False):
+ if resume:
+ # Read header names and previously logged values.
+ with open(fpath, 'r') as f:
+ header_line = f.readline()
+ self.names = header_line.rstrip().split('\t')
+ self.numbers = {}
+ for _, name in enumerate(self.names):
+ self.numbers[name] = []
+ for numbers in f:
+ numbers = numbers.rstrip().split('\t')
+ for i in range(0, len(numbers)):
+ self.numbers[self.names[i]].append(float(numbers[i]))
+
+ self.file = open(fpath, 'a')
+ self.header_written = True
+ else:
+ self.file = open(fpath, 'w')
+ self.header_written = False
+
+ def _write_line(self, field_values):
+ self.file.write('\t'.join(field_values) + '\n')
+ self.file.flush()
+
+ def set_names(self, names):
+ """Set field names and write log header line."""
+ assert not self.header_written, 'Log header has already been written'
+ self.names = names
+ self.numbers = {name: [] for name in self.names}
+ self._write_line(self.names)
+ self.header_written = True
+
+ def append(self, numbers):
+ """Append values to the log."""
+ assert self.header_written, 'Log header has not been written yet (use `set_names`)'
+ assert len(self.names) == len(numbers), 'Numbers do not match names'
+ for index, num in enumerate(numbers):
+ self.numbers[self.names[index]].append(num)
+ self._write_line(['{0:.6f}'.format(num) for num in numbers])
+
+ def plot(self, ax, names=None):
+ """Plot logged metrics on a set of Matplotlib axes."""
+ names = self.names if names is None else names
+ for name in names:
+ values = self.numbers[name]
+ ax.plot(np.arange(len(values)), np.asarray(values))
+ ax.grid(True)
+ ax.legend(names, loc='best')
+
+ def plot_to_file(self, fpath, names=None, dpi=150):
+ """Plot logged metrics and save the resulting figure to a file."""
+ import matplotlib.pyplot as plt
+ fig = plt.figure(dpi=dpi)
+ ax = fig.subplots()
+ self.plot(ax, names)
+ fig.savefig(fpath)
+ plt.close(fig)
+ del ax, fig
+
+ def close(self):
+ self.file.close()
diff --git a/src/stacked_hourglass/utils/misc.py b/src/stacked_hourglass/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d754c55dc2206bbb2a5cabf18c4017b5c1ee3d04
--- /dev/null
+++ b/src/stacked_hourglass/utils/misc.py
@@ -0,0 +1,56 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import os
+import shutil
+
+import scipy.io
+import torch
+
+
+def to_numpy(tensor):
+ if torch.is_tensor(tensor):
+ return tensor.detach().cpu().numpy()
+ elif type(tensor).__module__ != 'numpy':
+ raise ValueError("Cannot convert {} to numpy array"
+ .format(type(tensor)))
+ return tensor
+
+
+def to_torch(ndarray):
+ if type(ndarray).__module__ == 'numpy':
+ return torch.from_numpy(ndarray)
+ elif not torch.is_tensor(ndarray):
+ raise ValueError("Cannot convert {} to torch tensor"
+ .format(type(ndarray)))
+ return ndarray
+
+
+def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None):
+ preds = to_numpy(preds)
+ filepath = os.path.join(checkpoint, filename)
+ torch.save(state, filepath)
+ scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds' : preds})
+
+ if snapshot and state['epoch'] % snapshot == 0:
+ shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
+
+ if is_best:
+ shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
+ scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds' : preds})
+
+
+def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
+ preds = to_numpy(preds)
+ filepath = os.path.join(checkpoint, filename)
+ scipy.io.savemat(filepath, mdict={'preds' : preds})
+
+
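+# example: calling this once per epoch with schedule=[60, 90] and gamma=0.1 multiplies the
+# learning rate by 0.1 at epoch 60 and again at epoch 90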
+def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
+ """Sets the learning rate to the initial LR decayed by schedule"""
+ if epoch in schedule:
+ lr *= gamma
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+ return lr
diff --git a/src/stacked_hourglass/utils/pilutil.py b/src/stacked_hourglass/utils/pilutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..4306a31e76581cf9a7dd9901b88be1a2df2a75f0
--- /dev/null
+++ b/src/stacked_hourglass/utils/pilutil.py
@@ -0,0 +1,509 @@
+"""
+A collection of image utilities using the Python Imaging Library (PIL).
+"""
+
+# Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import division, print_function, absolute_import
+
+import numpy
+from PIL import Image
+from numpy import (amin, amax, ravel, asarray, arange, ones, newaxis,
+ transpose, iscomplexobj, uint8, issubdtype, array)
+
+if not hasattr(Image, 'frombytes'):
+ Image.frombytes = Image.fromstring
+
+__all__ = ['fromimage', 'toimage', 'imsave', 'imread', 'bytescale',
+ 'imrotate', 'imresize']
+
+
+def bytescale(data, cmin=None, cmax=None, high=255, low=0):
+ """
+ Byte scales an array (image).
+
+ Byte scaling means converting the input image to uint8 dtype and scaling
+ the range to ``(low, high)`` (default 0-255).
+ If the input image already has dtype uint8, no scaling is done.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ Parameters
+ ----------
+ data : ndarray
+ PIL image data array.
+ cmin : scalar, optional
+ Bias scaling of small values. Default is ``data.min()``.
+ cmax : scalar, optional
+ Bias scaling of large values. Default is ``data.max()``.
+ high : scalar, optional
+ Scale max value to `high`. Default is 255.
+ low : scalar, optional
+ Scale min value to `low`. Default is 0.
+
+ Returns
+ -------
+ img_array : uint8 ndarray
+ The byte-scaled array.
+
+ Examples
+ --------
+ >>> img = numpy.array([[ 91.06794177, 3.39058326, 84.4221549 ],
+ ... [ 73.88003259, 80.91433048, 4.88878881],
+ ... [ 51.53875334, 34.45808177, 27.5873488 ]])
+ >>> bytescale(img)
+ array([[255, 0, 236],
+ [205, 225, 4],
+ [140, 90, 70]], dtype=uint8)
+ >>> bytescale(img, high=200, low=100)
+ array([[200, 100, 192],
+ [180, 188, 102],
+ [155, 135, 128]], dtype=uint8)
+ >>> bytescale(img, cmin=0, cmax=255)
+ array([[91, 3, 84],
+ [74, 81, 5],
+ [52, 34, 28]], dtype=uint8)
+
+ """
+ if data.dtype == uint8:
+ return data
+
+ if high > 255:
+ raise ValueError("`high` should be less than or equal to 255.")
+ if low < 0:
+ raise ValueError("`low` should be greater than or equal to 0.")
+ if high < low:
+ raise ValueError("`high` should be greater than or equal to `low`.")
+
+ if cmin is None:
+ cmin = data.min()
+ if cmax is None:
+ cmax = data.max()
+
+ cscale = cmax - cmin
+ if cscale < 0:
+ raise ValueError("`cmax` should be larger than `cmin`.")
+ elif cscale == 0:
+ cscale = 1
+
+ scale = float(high - low) / cscale
+ bytedata = (data - cmin) * scale + low
+ return (bytedata.clip(low, high) + 0.5).astype(uint8)
+
+
+def imread(name, flatten=False, mode=None):
+ """
+ Read an image from a file as an array.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ Parameters
+ ----------
+ name : str or file object
+ The file name or file object to be read.
+ flatten : bool, optional
+ If True, flattens the color layers into a single gray-scale layer.
+ mode : str, optional
+ Mode to convert image to, e.g. ``'RGB'``. See the Notes for more
+ details.
+
+ Returns
+ -------
+ imread : ndarray
+ The array obtained by reading the image.
+
+ Notes
+ -----
+ `imread` uses the Python Imaging Library (PIL) to read an image.
+ The following notes are from the PIL documentation.
+
+ `mode` can be one of the following strings:
+
+ * 'L' (8-bit pixels, black and white)
+ * 'P' (8-bit pixels, mapped to any other mode using a color palette)
+ * 'RGB' (3x8-bit pixels, true color)
+ * 'RGBA' (4x8-bit pixels, true color with transparency mask)
+ * 'CMYK' (4x8-bit pixels, color separation)
+ * 'YCbCr' (3x8-bit pixels, color video format)
+ * 'I' (32-bit signed integer pixels)
+ * 'F' (32-bit floating point pixels)
+
+ PIL also provides limited support for a few special modes, including
+ 'LA' ('L' with alpha), 'RGBX' (true color with padding) and 'RGBa'
+ (true color with premultiplied alpha).
+
+ When translating a color image to black and white (mode 'L', 'I' or
+ 'F'), the library uses the ITU-R 601-2 luma transform::
+
+ L = R * 299/1000 + G * 587/1000 + B * 114/1000
+
+ When `flatten` is True, the image is converted using mode 'F'.
+ When `mode` is not None and `flatten` is True, the image is first
+ converted according to `mode`, and the result is then flattened using
+ mode 'F'.
+
+ """
+
+ im = Image.open(name)
+ return fromimage(im, flatten=flatten, mode=mode)
+
+
+def imsave(name, arr, format=None):
+ """
+ Save an array as an image.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ .. warning::
+
+ This function uses `bytescale` under the hood to rescale images to use
+ the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
+ It will also cast data for 2-D images to ``uint32`` for ``mode=None``
+ (which is the default).
+
+ Parameters
+ ----------
+ name : str or file object
+ Output file name or file object.
+ arr : ndarray, MxN or MxNx3 or MxNx4
+ Array containing image values. If the shape is ``MxN``, the array
+ represents a grey-level image. Shape ``MxNx3`` stores the red, green
+ and blue bands along the last dimension. An alpha layer may be
+ included, specified as the last colour band of an ``MxNx4`` array.
+ format : str
+ Image format. If omitted, the format to use is determined from the
+ file name extension. If a file object was used instead of a file name,
+ this parameter should always be used.
+
+ Examples
+ --------
+ Construct an array of gradient intensity values and save to file:
+
+ >>> x = numpy.zeros((255, 255), dtype=numpy.uint8)
+ >>> x[:] = numpy.arange(255)
+ >>> imsave('gradient.png', x)
+
+ Construct an array with three colour bands (R, G, B) and store to file:
+
+ >>> rgb = numpy.zeros((255, 255, 3), dtype=numpy.uint8)
+ >>> rgb[..., 0] = numpy.arange(255)
+ >>> rgb[..., 1] = 55
+    >>> rgb[..., 2] = 255 - numpy.arange(255)
+ >>> imsave('rgb_gradient.png', rgb)
+
+ """
+ im = toimage(arr, channel_axis=2)
+ if format is None:
+ im.save(name)
+ else:
+ im.save(name, format)
+ return
+
+
+def fromimage(im, flatten=False, mode=None):
+ """
+ Return a copy of a PIL image as a numpy array.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ Parameters
+ ----------
+ im : PIL image
+ Input image.
+ flatten : bool
+ If true, convert the output to grey-scale.
+ mode : str, optional
+ Mode to convert image to, e.g. ``'RGB'``. See the Notes of the
+ `imread` docstring for more details.
+
+ Returns
+ -------
+ fromimage : ndarray
+ The different colour bands/channels are stored in the
+ third dimension, such that a grey-image is MxN, an
+ RGB-image MxNx3 and an RGBA-image MxNx4.
+
+ """
+ if not Image.isImageType(im):
+ raise TypeError("Input is not a PIL image.")
+
+ if mode is not None:
+ if mode != im.mode:
+ im = im.convert(mode)
+ elif im.mode == 'P':
+ # Mode 'P' means there is an indexed "palette". If we leave the mode
+ # as 'P', then when we do `a = array(im)` below, `a` will be a 2-D
+ # containing the indices into the palette, and not a 3-D array
+ # containing the RGB or RGBA values.
+ if 'transparency' in im.info:
+ im = im.convert('RGBA')
+ else:
+ im = im.convert('RGB')
+
+ if flatten:
+ im = im.convert('F')
+ elif im.mode == '1':
+ # Workaround for crash in PIL. When im is 1-bit, the call array(im)
+ # can cause a seg. fault, or generate garbage. See
+ # https://github.com/scipy/scipy/issues/2138 and
+ # https://github.com/python-pillow/Pillow/issues/350.
+ #
+ # This converts im from a 1-bit image to an 8-bit image.
+ im = im.convert('L')
+
+ a = array(im)
+ return a
+
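+
+# Illustrative usage sketch, not part of the original scipy-derived code:
+# round-tripping a small array through PIL. The underscore-prefixed name is ad hoc.
+def _fromimage_example():
+    rgb = numpy.zeros((4, 4, 3), dtype=uint8)
+    rgb[..., 0] = 255                        # pure red test image
+    pil_im = toimage(rgb, channel_axis=2)    # PIL image in mode 'RGB'
+    back = fromimage(pil_im)                 # (4, 4, 3) uint8 array again
+    gray = fromimage(pil_im, flatten=True)   # (4, 4) float32 via mode 'F'
+    return back, gray
+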
+
+_errstr = "Mode is unknown or incompatible with input array shape."
+
+
+def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None,
+ mode=None, channel_axis=None):
+ """Takes a numpy array and returns a PIL image.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ The mode of the PIL image depends on the array shape and the `pal` and
+ `mode` keywords.
+
+ For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values
+ (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode
+ is given as 'F' or 'I' in which case a float and/or integer array is made.
+
+ .. warning::
+
+ This function uses `bytescale` under the hood to rescale images to use
+ the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
+ It will also cast data for 2-D images to ``uint32`` for ``mode=None``
+ (which is the default).
+
+ Notes
+ -----
+ For 3-D arrays, the `channel_axis` argument tells which dimension of the
+ array holds the channel data.
+
+ For 3-D arrays if one of the dimensions is 3, the mode is 'RGB'
+ by default or 'YCbCr' if selected.
+
+ The numpy array must be either 2 dimensional or 3 dimensional.
+
+ """
+ data = asarray(arr)
+ if iscomplexobj(data):
+ raise ValueError("Cannot convert a complex-valued array.")
+ shape = list(data.shape)
+ valid = len(shape) == 2 or ((len(shape) == 3) and
+ ((3 in shape) or (4 in shape)))
+ if not valid:
+ raise ValueError("'arr' does not have a suitable array shape for "
+ "any mode.")
+ if len(shape) == 2:
+ shape = (shape[1], shape[0]) # columns show up first
+ if mode == 'F':
+ data32 = data.astype(numpy.float32)
+            image = Image.frombytes(mode, shape, data32.tobytes())
+ return image
+ if mode in [None, 'L', 'P']:
+ bytedata = bytescale(data, high=high, low=low,
+ cmin=cmin, cmax=cmax)
+            image = Image.frombytes('L', shape, bytedata.tobytes())
+            if pal is not None:
+                image.putpalette(asarray(pal, dtype=uint8).tobytes())
+ # Becomes a mode='P' automagically.
+ elif mode == 'P': # default gray-scale
+ pal = (arange(0, 256, 1, dtype=uint8)[:, newaxis] *
+ ones((3,), dtype=uint8)[newaxis, :])
+                image.putpalette(asarray(pal, dtype=uint8).tobytes())
+ return image
+ if mode == '1': # high input gives threshold for 1
+ bytedata = (data > high)
+            image = Image.frombytes('1', shape, bytedata.tobytes())
+ return image
+ if cmin is None:
+ cmin = amin(ravel(data))
+ if cmax is None:
+ cmax = amax(ravel(data))
+ data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low
+ if mode == 'I':
+ data32 = data.astype(numpy.uint32)
+            image = Image.frombytes(mode, shape, data32.tobytes())
+ else:
+ raise ValueError(_errstr)
+ return image
+
+ # if here then 3-d array with a 3 or a 4 in the shape length.
+ # Check for 3 in datacube shape --- 'RGB' or 'YCbCr'
+ if channel_axis is None:
+ if (3 in shape):
+ ca = numpy.flatnonzero(asarray(shape) == 3)[0]
+ else:
+ ca = numpy.flatnonzero(asarray(shape) == 4)
+ if len(ca):
+ ca = ca[0]
+ else:
+ raise ValueError("Could not find channel dimension.")
+ else:
+ ca = channel_axis
+
+ numch = shape[ca]
+ if numch not in [3, 4]:
+ raise ValueError("Channel axis dimension is not valid.")
+
+ bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax)
+ if ca == 2:
+        strdata = bytedata.tobytes()
+        shape = (shape[1], shape[0])
+    elif ca == 1:
+        strdata = transpose(bytedata, (0, 2, 1)).tobytes()
+        shape = (shape[2], shape[0])
+    elif ca == 0:
+        strdata = transpose(bytedata, (1, 2, 0)).tobytes()
+ shape = (shape[2], shape[1])
+ else:
+ raise ValueError("Unexpected channel axis.")
+ if mode is None:
+ if numch == 3:
+ mode = 'RGB'
+ else:
+ mode = 'RGBA'
+
+ if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']:
+ raise ValueError(_errstr)
+
+ if mode in ['RGB', 'YCbCr']:
+ if numch != 3:
+ raise ValueError("Invalid array shape for mode.")
+ if mode in ['RGBA', 'CMYK']:
+ if numch != 4:
+ raise ValueError("Invalid array shape for mode.")
+
+ # Here we know data and mode is correct
+ image = Image.frombytes(mode, shape, strdata)
+ return image
+
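+
+# Illustrative usage sketch, not part of the original scipy-derived code: the PIL
+# mode chosen by `toimage` follows the array shape and keyword arguments.
+def _toimage_example():
+    gray = numpy.linspace(0.0, 1.0, 16).reshape(4, 4)
+    im_l = toimage(gray)                       # 2-D float -> mode 'L', rescaled to 0..255
+    im_f = toimage(gray, mode='F')             # keep floating point -> mode 'F'
+    rgba = numpy.arange(64, dtype=float).reshape(4, 4, 4)
+    im_rgba = toimage(rgba, channel_axis=2)    # 4 channels on the last axis -> mode 'RGBA'
+    return im_l.mode, im_f.mode, im_rgba.mode  # ('L', 'F', 'RGBA')
+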
+
+def imrotate(arr, angle, interp='bilinear'):
+ """
+ Rotate an image counter-clockwise by angle degrees.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ .. warning::
+
+ This function uses `bytescale` under the hood to rescale images to use
+ the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
+ It will also cast data for 2-D images to ``uint32`` for ``mode=None``
+ (which is the default).
+
+ Parameters
+ ----------
+ arr : ndarray
+ Input array of image to be rotated.
+ angle : float
+ The angle of rotation.
+ interp : str, optional
+ Interpolation
+
+ - 'nearest' : for nearest neighbor
+ - 'bilinear' : for bilinear
+ - 'lanczos' : for lanczos
+ - 'cubic' : for bicubic
+ - 'bicubic' : for bicubic
+
+ Returns
+ -------
+ imrotate : ndarray
+ The rotated array of image.
+
+ """
+ arr = asarray(arr)
+ func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3}
+ im = toimage(arr)
+ im = im.rotate(angle, resample=func[interp])
+ return fromimage(im)
+
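+
+# Illustrative usage sketch, not part of the original scipy-derived code: rotating a
+# small gradient; the result is rescaled uint8 because `toimage` is used internally.
+def _imrotate_example():
+    img = numpy.arange(64, dtype=numpy.uint8).reshape(8, 8)
+    rotated = imrotate(img, 45, interp='nearest')  # counter-clockwise, output stays 8x8
+    return rotated.dtype, rotated.shape            # (dtype('uint8'), (8, 8))
+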
+
+def imresize(arr, size, interp='bilinear', mode=None):
+ """
+ Resize an image.
+
+ This function is only available if Python Imaging Library (PIL) is installed.
+
+ .. warning::
+
+ This function uses `bytescale` under the hood to rescale images to use
+ the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
+ It will also cast data for 2-D images to ``uint32`` for ``mode=None``
+ (which is the default).
+
+ Parameters
+ ----------
+ arr : ndarray
+ The array of image to be resized.
+ size : int, float or tuple
+ * int - Percentage of current size.
+ * float - Fraction of current size.
+ * tuple - Size of the output image (height, width).
+
+ interp : str, optional
+ Interpolation to use for re-sizing ('nearest', 'lanczos', 'bilinear',
+ 'bicubic' or 'cubic').
+ mode : str, optional
+ The PIL image mode ('P', 'L', etc.) to convert `arr` before resizing.
+ If ``mode=None`` (the default), 2-D images will be treated like
+ ``mode='L'``, i.e. casting to long integer. For 3-D and 4-D arrays,
+ `mode` will be set to ``'RGB'`` and ``'RGBA'`` respectively.
+
+ Returns
+ -------
+ imresize : ndarray
+ The resized array of image.
+
+ See Also
+ --------
+ toimage : Implicitly used to convert `arr` according to `mode`.
+ scipy.ndimage.zoom : More generic implementation that does not use PIL.
+
+ """
+ im = toimage(arr, mode=mode)
+ ts = type(size)
+ if issubdtype(ts, numpy.signedinteger):
+ percent = size / 100.0
+ size = tuple((array(im.size)*percent).astype(int))
+ elif issubdtype(type(size), numpy.floating):
+ size = tuple((array(im.size)*size).astype(int))
+ else:
+ size = (size[1], size[0])
+ func = {'nearest': 0, 'lanczos': 1, 'bilinear': 2, 'bicubic': 3, 'cubic': 3}
+ imnew = im.resize(size, resample=func[interp])
+ return fromimage(imnew)
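+
+
+# Illustrative usage sketch, not part of the original scipy-derived code: the three
+# ways `imresize` interprets its `size` argument.
+def _imresize_example():
+    img = numpy.arange(64, dtype=numpy.uint8).reshape(8, 8)
+    doubled = imresize(img, 200)       # int   -> percentage of current size: (16, 16)
+    halved = imresize(img, 0.5)        # float -> fraction of current size:   (4, 4)
+    explicit = imresize(img, (6, 12))  # tuple -> (height, width):            (6, 12)
+    return doubled.shape, halved.shape, explicit.shape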
diff --git a/src/stacked_hourglass/utils/transforms.py b/src/stacked_hourglass/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7777e02f7a78e282c9032cb76325bafbbb16a5be
--- /dev/null
+++ b/src/stacked_hourglass/utils/transforms.py
@@ -0,0 +1,150 @@
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import numpy as np
+import torch
+
+from .imutils import im_to_numpy, im_to_torch
+from .misc import to_torch
+from .pilutil import imresize, imrotate
+
+
+def color_normalize(x, mean, std):
+ if x.size(0) == 1:
+ x = x.repeat(3, 1, 1)
+
+    # Note: only the mean is subtracted here; `std` is accepted but not applied,
+    # matching the upstream implementation (visualization.py undoes this by adding
+    # the mean back).
+    for t, m, s in zip(x, mean, std):
+        t.sub_(m)
+    return x
+
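+
+# Illustrative sketch, not part of the upstream code: normalising an input image
+# with per-channel statistics. The mean/std values are the ones used elsewhere in
+# this repository (see stacked_hourglass/utils/visualization.py).
+def _color_normalize_example():
+    img = torch.rand(3, 256, 256)
+    mean = [0.4404, 0.4440, 0.4327]
+    std = [0.2458, 0.2410, 0.2468]  # accepted but not applied by color_normalize
+    return color_normalize(img, mean, std)
+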
+
+def flip_back(flip_output, hflip_indices):
+ """flip and rearrange output maps"""
+ return fliplr(flip_output)[:, hflip_indices]
+
+
+def shufflelr(x, width, hflip_indices):
+ """flip and rearrange coords"""
+ # Flip horizontal
+ x[:, 0] = width - x[:, 0]
+ # Change left-right parts
+ x = x[hflip_indices]
+ return x
+
+
+def fliplr(x):
+ """Flip images horizontally."""
+ if torch.is_tensor(x):
+ return torch.flip(x, [-1])
+ else:
+ return np.ascontiguousarray(np.flip(x, -1))
+
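+
+# Illustrative sketch, not part of the upstream code: test-time horizontal flipping.
+# `hflip_indices` below is a hypothetical left/right swap for 4 keypoints
+# (0 <-> 1 form a left/right pair, 2 and 3 lie on the midline).
+def _flip_example():
+    hflip_indices = [1, 0, 2, 3]
+    heatmaps = torch.rand(2, 4, 64, 64)             # batch of predicted heatmaps
+    unflipped = flip_back(heatmaps, hflip_indices)  # undo a horizontal flip of the input
+    coords = torch.tensor([[10.0, 20.0], [50.0, 20.0], [30.0, 5.0], [30.0, 60.0]])
+    swapped = shufflelr(coords.clone(), width=64, hflip_indices=hflip_indices)
+    return unflipped.shape, swapped
+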
+
+def get_transform(center, scale, res, rot=0):
+ """
+ General image processing functions
+ """
+ # Generate transformation matrix
+ h = 200 * scale
+ t = np.zeros((3, 3))
+ t[0, 0] = float(res[1]) / h
+ t[1, 1] = float(res[0]) / h
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
+ t[2, 2] = 1
+ if not rot == 0:
+ rot = -rot # To match direction of rotation from cropping
+ rot_mat = np.zeros((3,3))
+ rot_rad = rot * np.pi / 180
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0,:2] = [cs, -sn]
+ rot_mat[1,:2] = [sn, cs]
+ rot_mat[2,2] = 1
+ # Need to rotate around center
+ t_mat = np.eye(3)
+ t_mat[0,2] = -res[1]/2
+ t_mat[1,2] = -res[0]/2
+ t_inv = t_mat.copy()
+ t_inv[:2,2] *= -1
+ t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
+ return t
+
+
+def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
+ # Transform pixel location to different reference
+ t = get_transform(center, scale, res, rot=rot)
+ if invert:
+ t = np.linalg.inv(t)
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
+ new_pt = np.dot(t, new_pt)
+ if as_int:
+ return new_pt[:2].astype(int) + 1
+ else:
+ return new_pt[:2] + 1
+
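+
+# Illustrative sketch, not part of the upstream code: mapping a point from the
+# original image into a 256x256 crop reference frame and back. The centre, scale
+# and resolution values are arbitrary examples.
+def _transform_example():
+    center = [320.0, 240.0]  # bounding-box centre in the original image
+    scale = 1.5              # the box spans roughly 200 * scale pixels
+    res = [256, 256]         # output resolution of the crop
+    pt_crop = transform([320, 240], center, scale, res)         # approx. [128, 128]
+    pt_back = transform(pt_crop, center, scale, res, invert=1)  # approx. [320, 240]
+    return pt_crop, pt_back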
+
+
+def transform_preds(coords, center, scale, res):
+ # size = coords.size()
+ # coords = coords.view(-1, coords.size(-1))
+ # print(coords.size())
+ for p in range(coords.size(0)):
+ coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
+ return coords
+
+
+def crop(img, center, scale, res, rot=0, interp='bilinear'):
+ # mode = 'F'
+
+ img = im_to_numpy(img)
+
+ # Preprocessing for efficient cropping
+ ht, wd = img.shape[0], img.shape[1]
+ sf = scale * 200.0 / res[0]
+ if sf < 2:
+ sf = 1
+ else:
+        # use np.floor directly (the np.math alias was removed in recent NumPy versions)
+        new_size = int(np.floor(max(ht, wd) / sf))
+        new_ht = int(np.floor(ht / sf))
+        new_wd = int(np.floor(wd / sf))
+ if new_size < 2:
+ return torch.zeros(res[0], res[1], img.shape[2]) \
+ if len(img.shape) > 2 else torch.zeros(res[0], res[1])
+ else:
+ img = imresize(img, [new_ht, new_wd], interp=interp) # , mode=mode)
+ center = center * 1.0 / sf
+ scale = scale / sf
+
+ # Upper left point
+ ul = np.array(transform([0, 0], center, scale, res, invert=1))
+ # Bottom right point
+ br = np.array(transform(res, center, scale, res, invert=1))
+
+ # Padding so that when rotated proper amount of context is included
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
+ if not rot == 0:
+ ul -= pad
+ br += pad
+
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
+ if len(img.shape) > 2:
+ new_shape += [img.shape[2]]
+ new_img = np.zeros(new_shape)
+
+ # Range to fill new array
+ new_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ new_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Range to sample from original image
+ old_x = max(0, ul[0]), min(img.shape[1], br[0])
+ old_y = max(0, ul[1]), min(img.shape[0], br[1])
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
+
+ if not rot == 0:
+ # Remove padding
+ new_img = imrotate(new_img, rot, interp=interp) # , mode=mode)
+ new_img = new_img[pad:-pad, pad:-pad]
+
+ new_img = im_to_torch(imresize(new_img, res, interp=interp)) #, mode=mode))
+ return new_img
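+
+
+# Illustrative sketch, not part of the upstream code: cutting a 256x256 patch
+# around a hypothetical detection centre from a CxHxW image tensor.
+def _crop_example():
+    img = torch.rand(3, 480, 640)      # dummy RGB image with values in [0, 1]
+    center = np.array([320.0, 240.0])  # (x, y) centre of the region of interest
+    scale = 1.2                        # region height is roughly 200 * scale pixels
+    patch = crop(img, center, scale, res=[256, 256], rot=15)
+    return patch.shape                 # torch.Size([3, 256, 256])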
diff --git a/src/stacked_hourglass/utils/visualization.py b/src/stacked_hourglass/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..4487e7f10c348af91b3958081f6f029308440772
--- /dev/null
+++ b/src/stacked_hourglass/utils/visualization.py
@@ -0,0 +1,179 @@
+
+# Modified from:
+# https://github.com/anibali/pytorch-stacked-hourglass
+# https://github.com/bearpaw/pytorch-pose
+
+import matplotlib as mpl
+mpl.use('Agg')
+import matplotlib.pyplot as plt
+import io   # needed by get_img_from_fig below
+import numpy as np
+import cv2
+import torch
+
+# import stacked_hourglass.datasets.utils_stanext as utils_stanext
+# COLORS, labels = utils_stanext.load_keypoint_labels_and_colours()
+COLORS = ['#d82400', '#d82400', '#d82400', '#fcfc00', '#fcfc00', '#fcfc00', '#48b455', '#48b455', '#48b455', '#0090aa', '#0090aa', '#0090aa', '#d848ff', '#d848ff', '#fc90aa', '#006caa', '#d89000', '#d89000', '#fc90aa', '#006caa', '#ededed', '#ededed', '#a9d08e', '#a9d08e']
+RGB_MEAN = [0.4404, 0.4440, 0.4327]
+RGB_STD = [0.2458, 0.2410, 0.2468]
+
+
+
+def get_img_from_fig(fig, dpi=180):
+ buf = io.BytesIO()
+ fig.savefig(buf, format="png", dpi=dpi)
+ buf.seek(0)
+ img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
+ buf.close()
+ img = cv2.imdecode(img_arr, 1)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+def save_input_image_with_keypoints(img, tpts, out_path='./test_input.png', colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD, ratio_in_out=4., threshold=0.3, print_scores=False):
+ """
+ img has shape (3, 256, 256) and is a torch tensor
+    tpts has shape (n_keypoints, 3) and is a torch tensor
+    -> this function was tested with the MPII dataset and the results look ok
+ """
+ # reverse color normalization
+    for t, m, s in zip(img, rgb_mean, rgb_std):
+        t.add_(m)  # inverse of transforms.color_normalize()
+ img_np = img.detach().cpu().numpy().transpose(1, 2, 0)
+ # tpts_np = tpts.detach().cpu().numpy()
+ # plot image
+ fig, ax = plt.subplots()
+ plt.imshow(img_np) # plt.imshow(im)
+ plt.gca().set_axis_off()
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
+ plt.margins(0,0)
+ # plot all visible keypoints
+
+ for idx, (x, y, v) in enumerate(tpts):
+ if v > threshold:
+ x = int(x*ratio_in_out)
+ y = int(y*ratio_in_out)
+ plt.scatter([x], [y], c=[colors[idx]], marker="x", s=50)
+ if print_scores:
+ txt = '{:2.2f}'.format(v.item())
+ plt.annotate(txt, (x, y)) # , c=colors[idx])
+
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
+
+ plt.close()
+ return
+
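+
+# Illustrative sketch, not part of the upstream code: calling the plotting helper
+# with dummy data. Keypoint coordinates are given in heatmap space (64x64 here) and
+# are scaled up by `ratio_in_out=4.` inside the function; the output path is arbitrary.
+def _keypoint_plot_example():
+    img = torch.rand(3, 256, 256)                     # expected to be colour-normalised; the helper adds the mean back
+    tpts = torch.zeros(24, 3)                         # 24 keypoints: (x, y, score)
+    tpts[:, 0] = torch.randint(0, 64, (24,)).float()  # x in heatmap coordinates
+    tpts[:, 1] = torch.randint(0, 64, (24,)).float()  # y in heatmap coordinates
+    tpts[:, 2] = 1.0                                  # mark every keypoint as visible
+    save_input_image_with_keypoints(img, tpts, out_path='./test_keypoints.png')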
+
+
+def save_input_image(img, out_path, colors=COLORS, rgb_mean=RGB_MEAN, rgb_std=RGB_STD):
+    for t, m, s in zip(img, rgb_mean, rgb_std):
+        t.add_(m)  # inverse of transforms.color_normalize()
+ img_np = img.detach().cpu().numpy().transpose(1, 2, 0)
+ plt.imsave(out_path, img_np)
+ return
+
+######################################################################
+def get_bodypart_colors():
+ # body colors
+ n_body = 8
+ c = np.arange(1, n_body + 1)
+ norm = mpl.colors.Normalize(vmin=c.min(), vmax=c.max())
+ cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.gist_rainbow)
+ cmap.set_array([])
+ body_cols = []
+ for i in range(0, n_body):
+ body_cols.append(cmap.to_rgba(i + 1))
+ # head colors
+ n_blue = 5
+ c = np.arange(1, n_blue + 1)
+ norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1)
+ cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Blues)
+ cmap.set_array([])
+ head_cols = []
+ for i in range(0, n_body):
+ head_cols.append(cmap.to_rgba(i + 1))
+ # torso colors
+ n_blue = 2
+ c = np.arange(1, n_blue + 1)
+ norm = mpl.colors.Normalize(vmin=c.min()-1, vmax=c.max()+1)
+ cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Greens)
+ cmap.set_array([])
+ torso_cols = []
+ for i in range(0, n_body):
+ torso_cols.append(cmap.to_rgba(i + 1))
+ return body_cols, head_cols, torso_cols
+body_cols, head_cols, torso_cols = get_bodypart_colors()
+tbp_dict = {'full_body': [0, 8],
+ 'head': [8, 13],
+ 'torso': [13, 15]}
+
+def save_image_with_part_segmentation(partseg_big, seg_big, input_image_np, ind_img, out_path_seg=None, out_path_seg_overlay=None, thr=0.3):
+ soft_max = torch.nn.Softmax(dim=0)
+    # create dict with results
+ tbp_dict_res = {}
+ for ind_tbp, part in enumerate(['full_body', 'head', 'torso']):
+ partseg_tbp = partseg_big[:, tbp_dict[part][0]:tbp_dict[part][1], :, :]
+ segm_img_pred = soft_max((partseg_tbp[ind_img, :, :, :])) # [1, :, :]
+ m_v, m_i = segm_img_pred.max(axis=0)
+ tbp_dict_res[part] = {
+ 'inds': tbp_dict[part],
+ 'seg_probs': segm_img_pred,
+ 'seg_max_inds': m_i,
+ 'seg_max_values': m_v}
+ # create output_image
+ partseg_image = np.zeros((256, 256, 3))
+ for ind_sp in range(0, 5):
+ # partseg_image[tbp_dict_res['head']['seg_max_inds']==ind_sp, :] = head_cols[ind_sp][0:3]
+ mask_a = tbp_dict_res['full_body']['seg_max_inds']==1
+ mask_b = tbp_dict_res['head']['seg_max_inds']==ind_sp
+ partseg_image[mask_a*mask_b, :] = head_cols[ind_sp][0:3]
+ for ind_sp in range(0, 2):
+ # partseg_image[tbp_dict_res['torso']['seg_max_inds']==ind_sp, :] = torso_cols[ind_sp][0:3]
+ mask_a = tbp_dict_res['full_body']['seg_max_inds']==2
+ mask_b = tbp_dict_res['torso']['seg_max_inds']==ind_sp
+ partseg_image[mask_a*mask_b, :] = torso_cols[ind_sp][0:3]
+ for ind_sp in range(0, 8):
+ if (not ind_sp == 1) and (not ind_sp == 2): # head and torso
+ partseg_image[tbp_dict_res['full_body']['seg_max_inds']==ind_sp, :] = body_cols[ind_sp][0:3]
+ partseg_image[soft_max((seg_big[ind_img, :, :, :]))[1, :, :]