Note
Go to the end to download the full example code.
CSI Human Reconstruction Tutorial
# !pip install pysensing
In this tutorial, we implement code for the CSI-based human reconstruction task.
import sys
# Adds '../..' to the module search path. Relative sys.path entries are
# resolved against the current working directory, presumably so a local
# pysensing checkout can be imported without installing the package.
# NOTE(review): confirm the tutorial is meant to be run from two directory
# levels below the repository root.
sys.path.append('../..')
# pysensing CSI helpers: dataset loaders and the model zoo.
# predict/train/embedding are imported here but not used in the visible part
# of this tutorial.
import pysensing.csi.dataset.get_dataloader as get_dataloader
import pysensing.csi.model.get_model as get_model
import pysensing.csi.inference.predict as predict
import pysensing.csi.inference.train as train
import pysensing.csi.inference.embedding as embedding
import torch
Load the data
CSI reconstruction dataset:
# HandFi
# CSI size : 6, 20, 114
# image : 144, 144
# joints2d : 2, 42
# joints3d : 2, 21
# train number : 3600
# test number : 400
train_loader, test_loader = get_dataloader.load_recon_dataset('HandFi', batch_size=32, return_train=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inspect a single training batch: move every tensor to `device` as float32,
# split the joint annotations, print the batch and stop after the first one.
for data in train_loader:
    # Each batch is nested as ((joints, image), csi).
    ((joints, image), csi) = data
    # Keep the first 21 entries along the last axis of the joint tensor.
    joint = joints[:, :, 0:21].to(device, dtype=torch.float)
    img = image.to(device, dtype=torch.float)
    csi = csi.to(device, dtype=torch.float)
    # Rows 0-1 of `joint` are flattened into 42 values per sample. `reshape`
    # is used instead of `view`: when `.to(...)` above does not copy (e.g. the
    # joints are already float32 on the target device), `joint` is still a
    # strided slice of `joints`, and `view` would fail if the original last
    # axis is longer than 21.
    # NOTE(review): the variable names suggest rows 0-1 are 2D coordinates and
    # row 2 is the extra 3D component; confirm against the dataset loader.
    joint2d = joint[:, 0:2, :]
    joint2d = joint2d.reshape(-1, 42)
    # Row 2 of `joint`: 21 values per sample.
    joint3d = joint[:, 2, :]
    print('data:', csi)
    print('img:', img)
    print('joint:', joint)
    break
/data1/msc/zyj/yunjiao_csi/1028/yunjiao_csi/tutorials/csi_source/data/csi
using dataset: HandFi
data: tensor([[[[ -32.7930, -38.9111, -40.3795, ..., -48.4554, -45.7634,
-42.0926],
[ -52.3264, -56.3516, -54.7415, ..., -70.0369, -63.5968,
-61.1817],
[ -40.2269, -48.7751, -48.2723, ..., 52.2950, 48.2723,
45.7581],
...,
[ -30.3462, -28.4496, -27.5012, ..., -71.1239, -63.5373,
-61.6407],
[ -31.2848, -35.6300, -35.2823, ..., 42.5821, 39.9751,
36.4990],
[ -27.3576, -33.3301, -32.1741, ..., 40.6511, 36.6053,
35.2567]],
[[ -11.5020, -12.7257, -9.5442, ..., -11.5020, -9.7890,
-6.3628],
[ 8.0502, 12.0753, 16.9055, ..., -12.8804, -6.4402,
-10.4653],
[ -29.1645, -30.6730, -28.6617, ..., -42.7411, -40.7298,
-38.2156],
...,
[ 42.6743, 50.2609, 55.0025, ..., -2.8450, 0.0000,
-0.9483],
[ 5.7356, 7.4736, 7.1260, ..., -6.2570, -6.6046,
-3.4761],
[ -14.2568, -14.2568, -15.9907, ..., -15.2201, -16.1834,
-15.4128]],
[[ -8.8741, -10.4875, -5.1093, ..., -36.3030, -30.9248,
-26.6222],
[ -6.2461, -3.1231, -1.0410, ..., -64.0227, -51.5304,
-54.1330],
[ -52.8793, -58.3875, -55.6334, ..., -0.5508, -1.6525,
-3.8558],
...,
[ 43.7343, 49.8239, 59.7887, ..., -60.3423, -49.8239,
-45.9487],
[ 9.0769, 13.7834, 14.7920, ..., 11.7663, 10.4216,
13.1111],
[ -15.6624, -16.7303, -13.1707, ..., 4.2716, -0.3560,
-0.3560]],
[[ 48.4040, 50.5553, 55.3957, ..., 47.5972, 45.9838,
41.9501],
[ 71.3098, 77.5559, 75.9944, ..., 52.5715, 52.0509,
49.9689],
[ 40.2103, 47.3710, 49.0235, ..., -80.4206, -77.1156,
-74.3615],
...,
[ 55.3599, 54.2527, 55.3599, ..., 60.8959, 58.1279,
57.0207],
[ 41.6865, 47.0654, 46.0568, ..., -50.0910, -50.4272,
-45.0483],
[ 39.5120, 45.5633, 44.1395, ..., -53.7505, -48.4111,
-49.4789]],
[[ 14.7038, 17.5045, 16.8043, ..., 59.9821, 52.2801,
49.0126],
[ 23.4112, 23.4112, 28.0016, ..., 85.8411, 79.8735,
73.4469],
[ 19.3270, 23.3722, 27.4174, ..., -45.3961, -40.4519,
-35.9573],
...,
[ 8.1417, 10.5842, 8.1417, ..., 85.8948, 79.3814,
72.0539],
[ 9.0292, 12.8157, 12.2332, ..., -36.4082, -34.6606,
-32.9130],
[ 12.9031, 16.5897, 19.6618, ..., -28.8783, -23.9628,
-22.4268]],
[[ -7.4686, -6.3016, -3.7343, ..., -21.9390, -20.3052,
-23.3393],
[ -9.1809, -12.3942, -7.8037, ..., -8.2628, -14.2303,
-11.4761],
[ 8.0904, 11.2367, 8.5399, ..., 71.9146, 70.5662,
67.8694],
...,
[ -24.4251, -28.4959, -26.8676, ..., -25.2392, -25.2392,
-25.2392],
[ -10.4856, -12.2332, -15.4371, ..., 35.8257, 36.1169,
33.7868],
[ -1.5361, -2.1505, -3.0722, ..., 44.8535, 42.7030,
40.5525]]],
[[[ -42.5309, -49.6194, -49.0741, ..., -46.3478, -44.7120,
-43.0762],
[ 39.3843, 43.5416, 43.9792, ..., 47.2612, 43.9792,
39.3843],
[ 40.7355, 43.1317, 45.5279, ..., 11.9810, 11.9810,
10.7829],
...,
[ -54.7987, -63.9318, -65.4540, ..., 4.5666, -1.5222,
0.0000],
[ -39.2957, -43.6351, -43.8761, ..., 40.0189, 37.3670,
35.1973],
[ -5.2770, -10.5540, -14.6132, ..., 45.4632, 43.8395,
39.3744]],
[[ 11.4506, 15.8128, 17.9938, ..., -25.0823, -20.7202,
-17.9938],
[ -10.9401, -16.4101, -19.6922, ..., 1.0940, -2.4068,
-1.0940],
[ -52.7165, -62.3013, -59.9051, ..., 76.6786, 67.0937,
63.4994],
...,
[ -14.7145, -10.1479, -14.2071, ..., -67.4835, -62.4096,
-55.8135],
[ -13.9825, -16.6344, -15.6700, ..., -25.5542, -21.9381,
-20.9738],
[ -45.4632, -49.9283, -51.9579, ..., -25.5730, -24.7612,
-22.3257]],
[[ -5.4406, -4.0805, -1.0881, ..., -61.4789, -53.8621,
-50.5977],
[ -11.9511, -17.1606, -20.2249, ..., 11.6447, 4.2901,
5.2095],
[ -30.0826, -35.7585, -38.0289, ..., 105.5728, 96.4913,
90.8153],
...,
[ -33.7500, -38.3523, -36.3069, ..., -84.3751, -81.3069,
-76.1933],
[ -12.1761, -13.8365, -15.4969, ..., -25.7359, -25.7359,
-25.7359],
[ -44.9476, -51.5253, -54.2660, ..., -2.1926, -4.6592,
-7.3999]],
[[ 45.9732, 51.1418, 52.2299, ..., 50.5977, 47.6054,
44.3410],
[ -38.9177, -44.7400, -45.3529, ..., -69.5616, -62.5135,
-59.7555],
[ -58.4624, -62.4356, -65.2735, ..., 28.9474, 24.4066,
21.5686],
...,
[ 48.5796, 54.2046, 56.7614, ..., -48.0682, -40.9091,
-36.3069],
[ 40.4026, 45.9372, 44.5536, ..., -61.1574, -60.6039,
-57.0064],
[ -10.1406, -6.3036, -7.1258, ..., -75.3694, -72.3546,
-68.2436]],
[[ 30.7796, 32.5726, 33.7679, ..., 47.2153, 45.1235,
39.4457],
[ -26.0058, -26.9458, -25.0658, ..., -32.2722, -25.6925,
-25.0658],
[ -31.4767, -33.0247, -39.2169, ..., -38.1849, -35.0888,
-32.5087],
...,
[ 31.5271, 42.0361, 47.8745, ..., 21.6019, 22.1857,
21.0181],
[ 28.4140, 31.6697, 34.3335, ..., -9.1753, -11.5432,
-8.5834],
[ -1.8744, 1.5620, 4.0612, ..., -30.9273, -26.2414,
-20.3058]],
[[ -2.6895, -6.8731, -3.8848, ..., -0.8965, -1.7930,
-3.2871],
[ 14.7262, 15.9795, 19.4260, ..., 27.2591, 28.1991,
22.5592],
[ 32.5087, 38.1849, 36.6368, ..., -54.1812, -54.1812,
-46.9571],
...,
[ 12.2605, 14.0120, 15.1797, ..., 56.6320, 48.4583,
45.5391],
[ 4.7357, 2.0719, 3.5517, ..., 42.9169, 36.9973,
34.0375],
[ 32.1769, 36.8629, 37.1753, ..., 35.6133, 37.8001,
31.8645]]],
[[[ -3.9861, -1.1389, 0.0000, ..., -51.2495, -48.9718,
-44.9857],
[ 25.5935, 26.0810, 25.3498, ..., 9.2624, 10.4812,
8.5312],
[ -23.2937, -30.8484, -33.9962, ..., 69.8810, 62.9559,
61.0672],
...,
[ 48.3282, 53.0167, 55.1807, ..., 33.1805, 32.8199,
31.3773],
[ 28.6661, 34.0748, 37.8609, ..., -12.9809, -10.8174,
-9.7357],
[ -27.7503, -31.5518, -36.4936, ..., 53.9801, 50.5588,
46.7574]],
[[ 48.4023, 53.5273, 55.2356, ..., -14.8054, -14.8054,
-14.2360],
[ -34.3685, -39.9747, -43.3872, ..., 47.5309, 44.1184,
41.6809],
[ -54.7716, -61.0672, -62.9559, ..., 10.0729, 7.5547,
8.8138],
...,
[ 11.1804, 7.9345, 8.2951, ..., 45.8036, 37.5084,
35.3445],
[ 38.9426, 40.0243, 41.1061, ..., 54.6278, 49.2191,
45.4330],
[ -35.7333, -41.8156, -39.9149, ..., -2.6610, -2.6610,
-2.6610]],
[[ 65.3350, 68.0154, 71.7010, ..., -63.3247, -52.9381,
-46.5721],
[ -52.9388, -61.4945, -66.3071, ..., 80.7450, 67.9114,
62.0293],
[ -80.8251, -80.1985, -93.9826, ..., 70.8002, 55.1365,
53.2568],
...,
[ 27.1127, 22.5939, 22.0291, ..., 84.1624, 79.0787,
71.7357],
[ 57.4664, 61.8352, 57.4664, ..., 62.1712, 62.5073,
57.8024],
[ -43.5171, -34.2429, -37.0965, ..., 2.1402, -2.8536,
-7.1339]],
[[ 13.7371, 11.7268, 7.7062, ..., 58.6340, 61.9845,
53.6082],
[ -21.9241, -18.7157, -20.3199, ..., 3.7431, 1.0695,
-3.2084],
[ 25.6886, 31.9541, 26.3151, ..., -85.8375, -77.6923,
-78.9454],
...,
[ -57.6145, -63.8278, -75.1248, ..., -11.8618, -19.2048,
-15.2509],
[ -29.2373, -36.6306, -38.3109, ..., 56.7943, 50.4091,
44.3600],
[ 41.3769, 52.0778, 52.7912, ..., -71.3395, -70.6261,
-63.4921]],
[[ -1.5530, 0.0000, 0.9318, ..., -53.7336, -50.3171,
-46.5899],
[ 7.2149, 4.9758, 4.2294, ..., 30.1034, 29.8547,
27.6156],
[ -24.4503, -26.6334, -29.2531, ..., 67.6750, 60.6892,
58.5062],
...,
[ 39.7358, 41.7351, 45.2339, ..., 36.9868, 39.9857,
36.4870],
[ 22.5564, 30.2811, 33.9890, ..., -2.1629, 0.0000,
1.5450],
[ -34.4755, -37.2963, -42.3109, ..., 43.5645, 37.6097,
33.2219]],
[[ 40.9991, 45.0369, 46.9005, ..., -7.7650, -4.3484,
-4.0378],
[ -35.5768, -41.5477, -43.2893, ..., 37.5671, 33.3377,
31.8450],
[ -45.8444, -50.2105, -51.0837, ..., -5.2394, -7.4224,
-6.1126],
...,
[ 7.7472, 4.7483, 3.9986, ..., 34.9875, 33.2382,
27.7401],
[ 31.5171, 34.9160, 34.2980, ..., 52.8375, 50.6746,
47.2757],
[ -16.6109, -18.1780, -17.8646, ..., -29.1475, -27.8938,
-28.5207]]],
...,
[[[ -48.2071, -55.4837, -56.3932, ..., -13.6435, -13.6435,
-12.7340],
[ -52.5296, -64.3911, -65.5208, ..., 68.3449, 65.5208,
59.8724],
[ 7.8825, 8.3462, 3.9412, ..., 30.8344, 30.3707,
28.0523],
...,
[ -4.0329, 2.4197, 4.8395, ..., -67.7527, -61.3000,
-58.8803],
[ 4.0634, 7.4495, 10.8356, ..., -46.0514, -43.3425,
-39.2791],
[ 2.1584, 4.3168, 8.6337, ..., -8.6337, -9.7129,
-9.7129]],
[[ -48.2071, -47.2976, -49.1167, ..., -74.5846, -69.1272,
-64.5794],
[ -20.3340, -19.7692, -22.0285, ..., 7.3428, 4.5187,
9.6022],
[ -39.4124, -44.7447, -44.2810, ..., 37.5577, 35.7030,
34.0801],
...,
[ 56.4605, 62.9132, 64.5263, ..., -17.7447, -14.5184,
-13.7118],
[ 54.8553, 62.3048, 65.6910, ..., -54.1781, -49.4375,
-46.7286],
[ 59.3565, 64.7525, 65.8317, ..., -67.9902, -64.7525,
-60.4357]],
[[ -31.8529, -32.6492, -31.8529, ..., -78.8359, -78.8359,
-68.4837],
[ -2.5217, 0.8406, 1.6811, ..., -1.6811, -5.0433,
-0.8406],
[ -45.3756, -51.7747, -49.1569, ..., 36.6495, 34.3226,
31.4139],
...,
[ 66.3063, 69.7961, 67.0043, ..., -12.5633, -6.9796,
-3.4898],
[ 57.7810, 66.5738, 70.3421, ..., -59.6652, -50.2444,
-43.3358],
[ 60.5831, 69.0365, 68.3321, ..., -74.6721, -71.8543,
-64.8098]],
[[ 70.0763, 75.6506, 79.6322, ..., 27.0749, 23.0933,
25.4823],
[ 63.8823, 72.2878, 73.1284, ..., -77.3312, -70.6067,
-71.4473],
[ 4.0722, 7.2717, 9.8896, ..., -41.8852, -39.8491,
-38.9765],
...,
[ -21.6368, -27.9185, -25.8246, ..., 76.7758, 71.8900,
67.7023],
[ -25.7502, -30.7747, -34.5430, ..., 56.5249, 53.3846,
50.2444],
[ -21.1336, -30.2915, -30.2915, ..., 21.1336, 22.5425,
18.3158]],
[[ 23.7182, 32.5544, 34.4146, ..., 24.6483, 24.1832,
24.6483],
[ 33.4727, 42.2814, 39.9324, ..., -70.4689, -62.2476,
-62.8348],
[ 1.5144, 2.7763, 3.2811, ..., -45.1782, -44.6735,
-40.1304],
...,
[ 6.8301, 1.3660, 7.2855, ..., 69.2119, 61.4711,
57.8283],
[ 2.2891, 2.7469, 0.9156, ..., 53.1067, 46.6973,
42.1191],
[ 1.2312, 2.8728, 3.2831, ..., 20.5197, 17.2365,
13.5430]],
[[ 34.4146, 38.1351, 39.9954, ..., 69.2943, 64.6437,
57.2027],
[ 18.7917, 20.5534, 21.1407, ..., 4.1107, 7.6341,
2.3490],
[ 28.7727, 34.5778, 29.2775, ..., -20.9486, -19.9390,
-17.6675],
...,
[ -41.8914, -46.9001, -45.5341, ..., 7.7408, 3.1874,
5.4641],
[ -42.5769, -44.4082, -44.8660, ..., 48.0707, 40.2878,
38.4566],
[ -40.6289, -45.1432, -46.3744, ..., 66.8941, 61.5590,
58.2758]]],
[[[ -30.1405, -36.7125, -33.9930, ..., -16.0900, -13.5972,
-13.8238],
[ 31.3107, 28.8549, 29.4689, ..., 58.3238, 53.4123,
45.4312],
[ 12.5678, 16.8247, 15.0003, ..., -22.2978, -18.6490,
-17.0274],
...,
[ -8.7890, -10.8816, -5.4408, ..., -63.1970, -57.7562,
-56.0821],
[ 33.4186, 36.1100, 37.2314, ..., 57.6414, 53.3800,
47.9971],
[ 31.1475, 35.2324, 29.8709, ..., -32.4240, -29.6156,
-28.0838]],
[[ -21.7555, -22.2088, -23.1153, ..., -54.1622, -49.6298,
-46.2305],
[ -42.3615, -46.0451, -54.0262, ..., -51.5705, -52.7984,
-49.1148],
[ 34.8656, 38.7170, 37.2981, ..., 53.1092, 50.2713,
46.6226],
...,
[ 45.6190, 51.0598, 55.6636, ..., -33.0633, -30.9707,
-25.1114],
[ -15.0271, -16.8214, -23.9986, ..., -3.5886, -4.7100,
-5.3829],
[ -26.5519, -31.4028, -32.1687, ..., 54.3804, 52.5933,
46.7212]],
[[ -23.1707, -21.9512, -22.7642, ..., -65.0406, -59.7560,
-54.4715],
[ -38.7183, -45.0101, -45.4940, ..., 5.8077, -0.4840,
-1.4519],
[ 40.1074, 44.6479, 43.8911, ..., 42.3776, 37.8372,
38.2155],
...,
[ 50.8550, 59.7250, 60.3164, ..., -78.0565, -72.7345,
-66.8211],
[ -24.7524, -25.9501, -30.7409, ..., 29.9424, 24.3532,
19.1631],
[ -36.9102, -40.5210, -44.9342, ..., 35.7066, 35.7066,
31.2935]],
[[ 38.6178, 44.7154, 45.9349, ..., -13.4146, -14.6341,
-12.1951],
[ -55.1736, -54.2057, -52.7537, ..., -89.5361, -85.6643,
-79.8565],
[ -17.4051, -24.2158, -27.2428, ..., 48.4316, 46.9181,
44.6479],
...,
[ 26.0188, 27.2015, 26.0188, ..., 21.2881, 23.6535,
24.8362],
[ -38.7255, -43.1171, -41.9194, ..., -57.8887, -57.0902,
-50.7025],
[ -33.7007, -34.5031, -29.6887, ..., 61.7845, 59.7785,
57.7726]],
[[ -16.6859, -15.5482, -20.0989, ..., -33.7509, -31.0964,
-32.9925],
[ 3.7433, 2.8075, -1.8716, ..., 78.1413, 63.1681,
61.7644],
[ 17.1091, 17.8219, 17.4655, ..., -12.1189, -3.9208,
-5.7030],
...,
[ 5.6243, 6.8295, 8.0347, ..., -75.5260, -69.9017,
-65.0809],
[ 2.6111, 4.1032, 4.1032, ..., 67.8896, 61.9213,
56.3260],
[ -3.4724, -2.3149, -0.3858, ..., -27.3935, -22.7636,
-20.0629]],
[[ 5.3091, 9.1014, 0.3792, ..., -61.8135, -52.7121,
-48.1615],
[ -24.3314, -23.8635, -21.5240, ..., -59.4248, -59.4248,
-52.4062],
[ 2.4951, 4.9901, 6.4159, ..., 71.2877, 65.5847,
59.5252],
...,
[ 18.8815, 24.1040, 19.2832, ..., -45.7977, -42.9855,
-35.3526],
[ -16.7859, -17.5319, -18.6510, ..., -22.3812, -20.8891,
-19.0240],
[ -19.2912, -20.0629, -18.9054, ..., 74.8499, 67.9051,
62.1177]]],
[[[ 46.8680, 56.0261, 57.6423, ..., -66.8004, -59.2584,
-58.7197],
[ 36.8100, 37.9749, 41.0035, ..., 57.7777, 50.7885,
47.5268],
[ -62.3675, -76.4504, -72.4267, ..., 22.1304, 16.0948,
18.1067],
...,
[ 14.1170, 22.5872, 22.5872, ..., -79.0551, -73.4083,
-64.9382],
[ 47.5145, 50.0928, 52.3027, ..., 61.1427, 55.6177,
52.3027],
[ 47.1487, 52.6040, 57.6695, ..., -58.4489, -52.2143,
-48.7074]],
[[ 46.3293, 51.7164, 51.7164, ..., 44.7132, 42.0196,
37.1712],
[ -27.4910, -33.7813, -41.4695, ..., 8.8530, 6.7563,
7.6882],
[ -44.2608, -46.2726, -48.2845, ..., -90.5334, -82.4860,
-72.4267],
...,
[ 79.0551, 93.1721, 95.9955, ..., -62.1148, -59.2914,
-53.6446],
[ -25.7830, -31.3080, -33.8863, ..., 19.5214, 17.3115,
14.3648],
[ 23.7692, 25.3278, 27.2761, ..., -31.5624, -28.0555,
-25.7175]],
[[ 47.9657, 47.9657, 48.8707, ..., 32.5804, 39.8205,
35.2955],
[ -33.4600, -40.1521, -45.3893, ..., 32.2962, 25.6042,
20.9489],
[ -44.3696, -45.0127, -40.5114, ..., -100.3139, -99.6709,
-88.0962],
...,
[ 89.7694, 102.5936, 100.9906, ..., -100.9906, -92.9754,
-81.7543],
[ -6.2048, -12.7974, -11.2462, ..., 62.0480, 51.9652,
49.2506],
[ 43.2276, 44.4510, 45.6744, ..., -70.9585, -63.6180,
-60.7633]],
[[ -54.3007, -66.9709, -65.1609, ..., 95.0263, 84.1661,
88.6912],
[ -37.8244, -37.5334, -39.5701, ..., -66.0472, -59.6462,
-55.8637],
[ 71.3772, 85.5240, 81.0228, ..., -50.8000, -46.2987,
-43.0835],
...,
[ -22.4424, -27.2514, -32.0605, ..., 81.7543, 78.5482,
68.9301],
[ -56.6188, -59.7212, -62.4358, ..., -53.5164, -50.8018,
-46.5360],
[ -38.3339, -40.7807, -44.0432, ..., 41.1886, 39.5573,
35.8871]],
[[ 31.6533, 33.8362, 38.7480, ..., -48.0256, -40.3852,
-37.1107],
[ 11.5907, 9.2726, 8.6930, ..., 49.8401, 44.0447,
42.3061],
[ -34.1257, -44.3635, -49.4823, ..., -14.2191, -15.3566,
-19.3379],
...,
[ 17.7250, 15.5980, 16.3070, ..., -98.5508, -90.0428,
-83.6618],
[ 26.2991, 26.2991, 29.8706, ..., 65.5855, 58.4425,
53.8970],
[ 29.6106, 31.0551, 30.6940, ..., -64.9990, -56.6935,
-55.6102]],
[[ 24.5586, 26.1958, 26.7415, ..., 66.5810, 60.5778,
56.2118],
[ -26.3689, -31.5847, -31.0052, ..., -28.9768, -26.0791,
-23.7610],
[ -18.7692, -22.7505, -26.7318, ..., -90.4332, -83.0393,
-73.9391],
...,
[ 46.7939, 53.1749, 53.8839, ..., -29.0689, -23.3969,
-27.6509],
[ -20.1302, -22.4030, -19.8055, ..., -3.2468, -5.1949,
-2.5974],
[ 12.2776, 14.0831, 10.8332, ..., -11.1943, -7.5832,
-5.4166]]]], device='cuda:0')
img: tensor([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
...,
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], device='cuda:0')
joint: tensor([[[0.4065, 0.3842, 0.2791, ..., 0.4321, 0.4197, 0.4292],
[0.3706, 0.5283, 0.4555, ..., 0.3087, 0.3569, 0.3985],
[0.8637, 0.8774, 0.8543, ..., 0.7560, 0.7454, 0.7595]],
[[0.3400, 0.3053, 0.2496, ..., 0.3744, 0.3568, 0.3587],
[0.3861, 0.5380, 0.4331, ..., 0.3324, 0.3786, 0.4183],
[0.8486, 0.8448, 0.8010, ..., 0.7469, 0.7428, 0.7609]],
[[0.3652, 0.3266, 0.2775, ..., 0.3989, 0.3775, 0.3696],
[0.3761, 0.5271, 0.4217, ..., 0.3347, 0.3765, 0.4194],
[0.8042, 0.8070, 0.7568, ..., 0.7023, 0.6862, 0.6876]],
...,
[[0.3437, 0.3123, 0.2372, ..., 0.4320, 0.4302, 0.4275],
[0.4288, 0.5759, 0.4799, ..., 0.3065, 0.2592, 0.2173],
[0.8367, 0.8344, 0.8248, ..., 0.8529, 0.8606, 0.8648]],
[[0.3488, 0.3174, 0.2413, ..., 0.4052, 0.3935, 0.3941],
[0.3413, 0.4948, 0.3969, ..., 0.3103, 0.3587, 0.3945],
[0.7915, 0.8000, 0.7705, ..., 0.6966, 0.6999, 0.7255]],
[[0.3609, 0.3908, 0.2932, ..., 0.3134, 0.2729, 0.2739],
[0.3475, 0.4973, 0.4819, ..., 0.2110, 0.2381, 0.2809],
[0.8023, 0.7771, 0.8549, ..., 0.7438, 0.7365, 0.7290]]],
device='cuda:0')
Load the model
Model zoo: AutoEncoder
# Fetch the HandFi reconstruction network from the pysensing model zoo;
# the second argument selects the architecture ('AutoEncoder').
model = get_model.load_recon_model(
    'HandFi',
    'AutoEncoder',
)
# Printing the module displays its layer hierarchy.
print(model)
CSI_AutoEncoder(
(encoder): Sequential(
(0): Conv2d(6, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
(1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): LeakyReLU(negative_slope=0.3, inplace=True)
)
(decoder): Sequential(
(0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(3): ReLU()
(4): ConvTranspose2d(32, 16, kernel_size=(4, 3), stride=(2, 2), padding=(1, 1))
(5): ReLU()
(6): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): ReLU()
(8): ConvTranspose2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): Tanh()
)
(fc): Sequential(
(0): Linear(in_features=13680, out_features=12544, bias=True)
(1): ReLU()
)
(dropout): Dropout2d(p=0.5, inplace=False)
(fj): Sequential(
(0): Linear(in_features=28160, out_features=2048, bias=True)
(1): ReLU()
)
(joint): Joints(
(joints): Sequential(
(0): Linear(in_features=2048, out_features=1024, bias=True)
(1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.01)
(3): Linear(in_features=1024, out_features=512, bias=True)
(4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=512, out_features=64, bias=True)
(7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): LeakyReLU(negative_slope=0.01)
)
(linear1): Linear(in_features=64, out_features=42, bias=True)
(linear2): Linear(in_features=64, out_features=21, bias=True)
)
(decmask): ResNet18Dec(
(linear): Linear(in_features=2048, out_features=512, bias=True)
(layer4): Sequential(
(0): BasicBlockDec(
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
(1): BasicBlockDec(
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ResizeConv2d(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): ResizeConv2d(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer3): Sequential(
(0): BasicBlockDec(
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
(1): BasicBlockDec(
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ResizeConv2d(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): ResizeConv2d(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer2): Sequential(
(0): BasicBlockDec(
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
(1): BasicBlockDec(
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): ResizeConv2d(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): ResizeConv2d(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(layer1): Sequential(
(0): BasicBlockDec(
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
(1): BasicBlockDec(
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
)
(conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1))
(conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1))
(conv3): Conv2d(16, 4, kernel_size=(5, 5), stride=(1, 1))
(conv4): Conv2d(4, 2, kernel_size=(5, 5), stride=(1, 1))
(conv5): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
(score_fr): Conv2d(4096, 2, kernel_size=(1, 1), stride=(1, 1))
(upscore): ConvTranspose2d(64, 64, kernel_size=(8, 8), stride=(4, 4), bias=False)
)
(incep): InceptionV3(
(Conv2d_1a_3x3): ConvBN(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(Conv2d_2a_3x3): ConvBN(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(Conv2d_2b_3x3): ConvBN(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(Mixed_5b): InceptionA(
(branch1x1): BasicConv2d(
(conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch5x5): Sequential(
(0): ConvBN(
(0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branch3x3): Sequential(
(0): ConvBN(
(0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): ConvBN(
(0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branchpool): Sequential(
(0): AvgPool2d(kernel_size=1, stride=1, padding=0)
(1): ConvBN(
(0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
(Mixed_5c): InceptionA(
(branch1x1): BasicConv2d(
(conv): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch5x5): Sequential(
(0): ConvBN(
(0): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branch3x3): Sequential(
(0): ConvBN(
(0): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): ConvBN(
(0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branchpool): Sequential(
(0): AvgPool2d(kernel_size=1, stride=1, padding=0)
(1): ConvBN(
(0): Conv2d(176, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
(Mixed_5d): InceptionA(
(branch1x1): BasicConv2d(
(conv): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch5x5): Sequential(
(0): ConvBN(
(0): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branch3x3): Sequential(
(0): ConvBN(
(0): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(1): ConvBN(
(0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(2): ConvBN(
(0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(branchpool): Sequential(
(0): AvgPool2d(kernel_size=1, stride=1, padding=0)
(1): ConvBN(
(0): Conv2d(160, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
(Mixed_6a): Conv2d(152, 312, kernel_size=(5, 6), stride=(2, 2), padding=(1, 1))
(Mixed_6b): InceptionC(
(branch1x1): BasicConv2d(
(conv): Conv2d(312, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch7x7): Sequential(
(0): BasicConv2d(
(conv): Conv2d(312, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(128, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch7x7stack): Sequential(
(0): BasicConv2d(
(conv): Conv2d(312, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): BasicConv2d(
(conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): BasicConv2d(
(conv): Conv2d(128, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch_pool): Sequential(
(0): AvgPool2d(kernel_size=3, stride=1, padding=1)
(1): BasicConv2d(
(conv): Conv2d(312, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(Mixed_6c): InceptionC(
(branch1x1): BasicConv2d(
(conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch7x7): Sequential(
(0): BasicConv2d(
(conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch7x7stack): Sequential(
(0): BasicConv2d(
(conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): BasicConv2d(
(conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch_pool): Sequential(
(0): AvgPool2d(kernel_size=3, stride=1, padding=1)
(1): BasicConv2d(
(conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(Mixed_6d): InceptionC(
(branch1x1): BasicConv2d(
(conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(branch7x7): Sequential(
(0): BasicConv2d(
(conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch7x7stack): Sequential(
(0): BasicConv2d(
(conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): BasicConv2d(
(conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
(bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): BasicConv2d(
(conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(branch_pool): Sequential(
(0): AvgPool2d(kernel_size=3, stride=1, padding=1)
(1): BasicConv2d(
(conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
(conv2): ConvBN(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
Model training¶
# Optimize every trainable parameter of the model with Adam (learning rate 1e-3).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# One epoch keeps the tutorial fast; the metrics printed below (IoU 0.000,
# pck 0.000) show this is far from enough for a useful model.
epoch_num = 1
train.recon_train(train_loader, model, epoch_num, optimizer, device)
Epoch:1, Loss:44.061765085
mPA: 0.949 | => IoU: 0.000 | => mpjpe: 0.439 | =>pck: 0.000
Model inference¶
# NOTE(review): presumably this loads the released HandFi AutoEncoder weights
# into `model`, replacing the single epoch of training above -- confirm in
# pysensing.csi.model.get_model.load_pretrain.
model = get_model.load_pretrain(model, 'HandFi', 'AutoEncoder', device=device)
# Predict on the CSI batch loaded in the data section; the first element of
# the returned tuple is not used in this tutorial.
output = predict.recon_predict(csi, 'HandFi', model, device)
_, mask, twod, threed = output
for label, tensor in (("mask", mask), ("twod", twod), ("threed", threed)):
    print(f"{label}:", tensor.shape)
mask: torch.Size([32, 1, 114, 114])
twod: torch.Size([32, 42])
threed: torch.Size([32, 21])
Evaluate metrics¶
# Segmentation metrics: predicted `mask` against the ground-truth image `img`.
IoUerr = train.IoU(img, mask)
mPAerr = train.mPA(img, mask)
# Pose metrics: ground-truth 2D/3D joints against the predicted `twod`/`threed`.
mpjpe, pck = train.mpjpe_pck(joint2d, joint3d, twod, threed)
# NOTE(review): `img`, `joint2d`, `joint3d` and `csi` all come from the first
# batch of `train_loader` (data-loading section), so these scores are measured
# on training data and likely overstate generalization; draw the batch from
# `test_loader` for a fair evaluation.
report = f'mPA: {mPAerr:.3f} | => IoU: {IoUerr:.3f} | => mpjpe: {mpjpe:.3f} | =>pck: {pck:.3f}\n'
print(report)
mPA: 0.998 | => IoU: 0.139 | => mpjpe: 0.011 | =>pck: 1.000
Generate embeddings¶
# Compute the CSI embedding produced by `model` for the same `csi` batch used
# above (the printed result is a 4-D tensor; its exact layout is defined by
# pysensing's recon_csi_embedding -- confirm there).
csi_embedding = embedding.recon_csi_embedding(csi, 'HandFi', model, device)
print('csi_embedding: ', csi_embedding)
csi_embedding: tensor([[[[-7.5988e-02, -1.3345e-01, -7.9132e-02, ..., -1.2570e-01,
-3.0211e-02, 8.9611e-02],
[-7.8937e-02, -7.9026e-02, 1.0359e-02, ..., -1.9818e-02,
1.3651e-01, 2.5493e-01],
[-6.4701e-02, -1.0918e-01, -8.3367e-02, ..., 3.1941e-01,
-1.4069e-02, -1.2549e-01],
...,
[-1.9620e-02, -2.9558e-02, 9.6607e-02, ..., -3.6337e-01,
-2.3918e-01, -2.0665e-01],
[-2.5418e-02, 6.4009e-02, 2.6116e-02, ..., -3.9826e-03,
4.3250e-02, 1.4635e-01],
[-7.6564e-02, -1.0956e-01, -1.9091e-01, ..., 2.4052e-01,
1.4318e-01, 8.8810e-02]],
[[-6.8691e-02, -1.3890e-01, -1.5969e-01, ..., -1.3222e-01,
-8.6994e-02, -7.1656e-03],
[-9.5087e-02, -1.7732e-01, -1.2714e-01, ..., -6.6031e-02,
-1.4130e-02, 3.1650e-01],
[-1.2589e-01, -2.6198e-01, -2.1262e-01, ..., 9.4285e-01,
7.7951e-01, 6.6286e-01],
...,
[-1.1734e-01, -2.0700e-01, -1.7124e-01, ..., 1.7257e-01,
-2.0734e-02, -2.0480e-02],
[-3.7680e-02, -4.1951e-02, 8.4990e-02, ..., -1.9991e-02,
-2.5701e-02, -6.3073e-02],
[ 7.5083e-03, -2.4141e-02, -2.7933e-02, ..., 1.0905e-01,
4.7774e-02, -1.5171e-02]],
[[ 9.8348e-02, 1.9889e-01, 3.1992e-01, ..., 2.6857e-01,
1.6400e-01, 4.3267e-02],
[ 5.6098e-02, 1.7994e-01, 5.7166e-02, ..., -9.1950e-02,
-6.3024e-02, -6.6400e-02],
[-5.9416e-03, -2.6637e-02, -1.1798e-01, ..., -6.1100e-01,
-4.4403e-01, -2.4352e-01],
...,
[ 4.9742e-02, -7.8758e-03, -2.8914e-03, ..., -1.6821e-01,
-2.0014e-01, -6.8226e-02],
[-3.1435e-02, -8.7355e-02, -1.8312e-01, ..., 5.9714e-01,
4.9378e-01, 2.9055e-01],
[ 4.0038e-02, -2.1485e-03, -7.1473e-02, ..., 9.5176e-02,
1.7183e-01, 6.3368e-02]],
...,
[[-1.9152e-03, -5.3906e-02, -3.0001e-02, ..., -8.9673e-02,
-7.1675e-02, -3.1333e-02],
[-5.0727e-02, -8.1216e-02, -2.1506e-02, ..., -3.3926e-02,
-2.7612e-03, 1.0606e-01],
[ 3.5804e-02, -2.7429e-02, 5.7590e-01, ..., 1.0959e+00,
1.0089e+00, 7.1672e-01],
...,
[-3.4329e-02, -8.6428e-02, -6.2210e-02, ..., 6.7720e-01,
3.4152e-01, 2.8824e-01],
[-5.4396e-02, -3.1631e-02, 2.2842e-01, ..., -2.0054e-01,
-1.6767e-01, -1.4687e-01],
[ 6.9414e-02, 1.4295e-01, 4.5608e-01, ..., -1.3391e-01,
-9.4052e-02, -8.8097e-02]],
[[-5.4808e-02, -7.3786e-02, -4.5198e-02, ..., -1.1224e-01,
-8.8538e-03, 1.2416e-01],
[-1.0270e-01, -1.5658e-01, -1.4829e-01, ..., -4.0651e-01,
-1.6581e-01, -2.9898e-02],
[-1.3807e-01, -2.1471e-01, -2.1912e-01, ..., -1.7778e-01,
-1.2842e-01, -1.0364e-01],
...,
[-1.2237e-01, -1.4201e-01, -1.7165e-01, ..., 1.0059e+00,
6.8678e-01, 3.9695e-01],
[-5.4163e-02, -7.8612e-02, -9.3758e-02, ..., 2.8166e-01,
1.9588e-01, 9.0473e-02],
[ 3.5573e-02, -1.4971e-02, 6.5952e-02, ..., 9.5610e-02,
2.6193e-02, -9.3885e-04]],
[[-1.2038e-02, 2.6270e-01, -1.5770e-02, ..., 1.4886e-01,
-1.1121e-03, -1.4652e-01],
[ 3.5575e-01, 6.8593e-01, 8.8052e-02, ..., 7.7226e-02,
-1.5856e-01, -3.2138e-01],
[ 6.0042e-01, 6.7811e-01, -7.3340e-02, ..., -2.6675e-01,
-5.4458e-01, -2.0304e-01],
...,
[ 2.9060e-01, 6.5125e-01, 2.5959e-01, ..., 2.2907e-01,
1.8849e-02, 3.0123e-01],
[ 3.3442e-01, 2.3328e-01, -3.7865e-02, ..., -5.5345e-02,
2.5826e-01, 2.3314e-02],
[-6.4319e-03, -1.4516e-02, -4.3971e-02, ..., -1.0781e-01,
-5.1529e-02, -2.8931e-02]]],
[[[-1.1201e-01, -1.6031e-01, -1.3094e-01, ..., -1.1048e-01,
-2.1807e-02, 1.9078e-01],
[ 2.8870e-01, 5.8584e-01, 5.4395e-01, ..., 6.9207e-01,
2.9979e-01, -5.5536e-02],
[-4.2719e-02, -7.1615e-02, -1.1386e-01, ..., -3.8829e-01,
-2.9365e-01, -2.7612e-01],
...,
[-6.9176e-02, -9.9855e-02, -1.3018e-01, ..., -1.8537e-01,
-1.2654e-01, -1.1904e-01],
[-1.2052e-01, -1.0508e-01, -1.2950e-01, ..., -1.0749e-01,
-2.7437e-02, 4.0822e-02],
[-8.3533e-02, -1.1825e-01, -1.8019e-01, ..., 3.3185e-01,
2.4575e-01, 1.3988e-01]],
[[-2.4792e-02, -8.9631e-02, -9.7114e-02, ..., -2.0839e-01,
-1.2118e-01, -1.7923e-02],
[-8.2709e-03, 3.7649e-02, 2.7741e-01, ..., 9.4608e-01,
8.0376e-01, 4.5409e-01],
[-2.2271e-02, -1.1002e-02, -2.5261e-02, ..., -7.4665e-03,
-7.9154e-02, -8.4814e-02],
...,
[-6.5269e-02, -1.4813e-01, -7.1302e-02, ..., 6.2216e-01,
1.7507e-01, -1.1795e-02],
[-1.1235e-02, -7.7565e-03, 1.5451e-01, ..., -1.1655e-02,
-5.2153e-02, -7.0035e-02],
[ 1.0802e-01, 1.0302e-01, 2.0105e-02, ..., 1.5653e-01,
5.5474e-02, -1.8608e-02]],
[[ 7.8802e-02, 2.4965e-01, 3.1480e-01, ..., 4.1662e-01,
2.4005e-01, 1.0356e-01],
[-8.9133e-02, -1.4291e-01, -2.2279e-01, ..., -4.0896e-01,
-2.5937e-01, -1.5681e-01],
[-2.7849e-03, -1.2922e-01, -1.2820e-01, ..., 1.7465e-01,
-9.0056e-02, -9.5152e-02],
...,
[-2.1662e-02, -1.1998e-01, -2.7085e-01, ..., -2.0106e-01,
-1.5820e-01, -8.5378e-02],
[ 5.4356e-02, -5.8501e-02, -1.7795e-01, ..., 5.3807e-01,
3.7422e-01, 1.9366e-01],
[-1.4098e-02, 1.0553e-02, -6.3063e-02, ..., 1.3944e-01,
2.4085e-01, 8.4258e-02]],
...,
[[-2.8043e-02, -4.8503e-02, -6.5487e-02, ..., -6.5838e-02,
-7.8326e-02, -6.3266e-03],
[-2.6876e-02, 1.5565e-04, 2.2422e-01, ..., 2.1512e-01,
4.5621e-01, 1.7036e-01],
[ 5.3850e-02, -1.9424e-02, 2.1780e-01, ..., 9.0915e-01,
4.7735e-01, 1.8738e-01],
...,
[ 5.9052e-02, -2.4386e-03, 7.7303e-01, ..., 4.6570e-01,
2.9815e-01, 1.7371e-01],
[-3.9974e-02, -2.1020e-02, 4.1256e-01, ..., -1.2656e-01,
-1.3295e-01, -9.9367e-02],
[ 7.0246e-02, 1.6595e-01, 4.4544e-01, ..., -1.5969e-01,
-1.2156e-01, -1.0217e-01]],
[[-5.4430e-02, -1.0481e-01, -8.6851e-02, ..., -1.0888e-01,
4.6569e-02, 1.6418e-01],
[ 6.2495e-02, 4.5346e-02, -8.1779e-02, ..., -1.3841e-01,
-1.1162e-01, -9.3298e-02],
[ 3.3360e-02, 3.5860e-01, 6.7352e-01, ..., 1.1239e+00,
3.5015e-01, 3.1024e-01],
...,
[-1.0170e-01, -1.4160e-01, -1.1348e-01, ..., 8.1138e-01,
4.3449e-01, 2.7265e-01],
[-3.6892e-02, -4.0478e-02, 9.2279e-02, ..., 6.6005e-01,
4.1285e-01, 2.5957e-01],
[ 1.7225e-01, 7.8906e-02, 2.1531e-01, ..., 1.4322e-01,
3.0350e-02, 2.1241e-03]],
[[-1.7732e-02, 4.2032e-01, 2.8272e-01, ..., 2.7503e-02,
-4.2699e-02, -2.0023e-01],
[ 1.2115e-01, -1.2649e-01, -1.9780e-01, ..., -1.9010e-01,
-3.5124e-01, -1.0214e-01],
[-1.1262e-02, -1.1930e-01, -9.5240e-02, ..., 3.6433e-01,
4.9932e-01, 6.3529e-01],
...,
[ 7.0015e-01, 5.2267e-01, -6.4452e-02, ..., -2.0909e-01,
-1.6590e-01, 5.3467e-02],
[ 4.5025e-01, 3.3467e-01, 4.1305e-02, ..., -1.5154e-02,
2.3561e-01, 6.5215e-02],
[-5.5921e-02, -1.0450e-01, -1.3866e-01, ..., -1.3112e-01,
-5.6114e-02, -2.3075e-02]]],
[[[-5.7434e-02, -1.4914e-02, -8.9773e-02, ..., 1.8952e-01,
8.1011e-02, 2.0715e-01],
[ 7.4782e-02, -4.0186e-02, -5.4041e-02, ..., 6.0871e-03,
1.6597e-01, -8.7688e-02],
[ 5.6568e-02, 5.9086e-02, 9.1099e-01, ..., -4.8499e-01,
-4.4338e-01, -3.3071e-01],
...,
[ 1.3687e-01, 8.0689e-03, 5.3674e-01, ..., 6.4444e-01,
3.5084e-01, -1.7442e-02],
[-2.8487e-02, 4.3667e-01, 1.0466e+00, ..., -8.0288e-02,
-6.0873e-02, -2.5624e-02],
[-6.0843e-02, -1.5800e-01, -2.7730e-01, ..., 4.4224e-01,
3.7650e-01, 1.3905e-01]],
[[ 3.1156e-01, 4.9859e-01, 6.1413e-01, ..., -3.9132e-01,
-1.9140e-01, -7.5594e-02],
[-2.3902e-02, -1.2193e-01, -1.9248e-01, ..., 1.0932e+00,
7.5196e-01, 2.6939e-01],
[-2.4425e-01, -3.3101e-01, -3.3450e-01, ..., 9.1637e-01,
2.0239e-01, 7.4754e-02],
...,
[-7.8633e-02, -2.2259e-01, -3.1126e-01, ..., 1.0238e+00,
4.6807e-01, 2.9726e-01],
[-4.1057e-02, 2.1352e-01, 5.4933e-01, ..., -5.1812e-02,
-4.6104e-02, 1.4153e-02],
[-2.2941e-02, -4.0784e-02, -9.4297e-02, ..., 2.7377e-01,
1.2808e-01, 4.1316e-02]],
[[-3.1916e-02, -1.9901e-02, -7.5594e-02, ..., 9.6380e-01,
4.4228e-01, 1.6175e-01],
[-8.1777e-03, 1.1186e-01, 4.1587e-01, ..., -3.8243e-01,
-2.3134e-01, -1.2859e-01],
[ 2.8381e-01, 1.7108e-01, 5.4225e-01, ..., -2.7967e-01,
-1.3795e-01, -1.2494e-01],
...,
[ 3.3475e-01, 6.1965e-01, 1.2893e+00, ..., -2.1037e-01,
-5.8931e-02, 1.0094e-01],
[-8.0970e-02, -2.0552e-01, -1.9528e-01, ..., 4.4057e-01,
1.8729e-01, 1.7780e-01],
[ 1.6259e-02, 1.8337e-01, -1.5995e-02, ..., -1.0156e-02,
5.9640e-02, 7.7374e-02]],
...,
[[-6.6453e-02, 1.2345e-03, -6.8085e-02, ..., 3.7254e-01,
1.4041e-02, 3.1234e-01],
[ 2.8949e-01, 1.6551e-01, 1.7720e-01, ..., -2.0364e-01,
3.0634e-03, -1.0169e-01],
[-2.4140e-02, -1.9360e-01, -1.4111e-01, ..., 4.5580e-01,
4.3118e-01, -6.7965e-02],
...,
[ 5.0570e-01, -1.4967e-03, -3.4613e-02, ..., -3.0379e-01,
-2.3738e-01, -2.0096e-01],
[-2.1262e-01, -2.7144e-01, -3.7416e-01, ..., -6.5362e-02,
-1.0723e-01, -7.3084e-02],
[ 1.1407e-01, 3.3217e-01, 7.4943e-01, ..., -1.7954e-01,
-1.4103e-01, -9.4326e-02]],
[[ 2.0149e-01, 3.7478e-02, -1.5635e-02, ..., 4.2450e-01,
3.9457e-01, 1.9363e-01],
[ 9.5819e-02, 8.9870e-02, -1.9502e-02, ..., 5.8218e-02,
1.0696e-01, 1.4387e-01],
[-1.1390e-01, 1.1920e-01, 4.2096e-01, ..., -4.6602e-01,
-4.1030e-01, -1.8032e-01],
...,
[ 3.4323e-01, 7.0066e-01, 7.1163e-01, ..., 1.7737e-02,
-5.3021e-02, -9.1859e-02],
[-2.7616e-02, 4.4884e-01, 8.0543e-01, ..., -5.2696e-02,
-1.3554e-02, -2.4830e-02],
[-2.1704e-02, -1.2163e-01, -1.4341e-01, ..., 4.8790e-01,
2.7137e-01, 8.0601e-02]],
[[-7.6124e-02, -5.0089e-02, 3.1880e-01, ..., -1.8540e-01,
-2.2487e-01, -2.1699e-01],
[-1.2993e-01, -1.1733e-01, -1.5225e-01, ..., 4.1419e-01,
1.7313e-01, 3.4507e-01],
[ 4.0585e-01, 4.9990e-01, -7.8072e-02, ..., -3.6138e-02,
3.7966e-01, 2.0279e-01],
...,
[-6.9440e-02, -6.4434e-02, -8.5714e-02, ..., 5.5404e-01,
4.8774e-01, 6.1422e-01],
[ 4.0948e-01, 1.6974e-01, 3.8686e-01, ..., 4.5283e-01,
3.5058e-01, -1.0249e-03],
[-1.1901e-01, -1.0644e-01, -2.3372e-01, ..., 2.5683e-02,
-1.1525e-02, 7.9798e-02]]],
...,
[[[-1.1181e-01, -2.3980e-01, -1.2242e-01, ..., -7.5282e-02,
-3.8972e-03, 3.2171e-01],
[-9.0789e-02, -6.8418e-02, 2.6450e-01, ..., 9.4353e-01,
5.2219e-01, -2.7978e-02],
[-1.0445e-01, -1.3948e-01, -1.9482e-01, ..., -4.4340e-01,
-3.2342e-01, -2.8784e-01],
...,
[ 2.6013e-01, 5.5673e-01, 3.4526e-01, ..., -6.6979e-02,
1.4295e-01, 1.2825e-01],
[ 2.1219e-01, 1.6811e-01, 3.4648e-01, ..., -1.7945e-01,
-5.7473e-02, -1.0353e-01],
[ 1.5309e-01, 3.6817e-01, 8.1527e-01, ..., -3.1560e-01,
-2.1173e-01, -1.5874e-01]],
[[-1.2270e-01, -2.3678e-01, -2.6664e-01, ..., -2.1526e-01,
-1.0538e-01, -6.1172e-03],
[-1.0818e-01, -1.4680e-01, -3.5169e-02, ..., 8.5371e-01,
9.0588e-01, 4.2803e-01],
[-6.3214e-02, -1.6061e-01, -9.8925e-02, ..., 1.2289e-01,
-8.6168e-02, -1.4889e-01],
...,
[ 4.4536e-01, 7.0577e-01, 6.2773e-01, ..., -3.8386e-02,
1.9970e-01, 5.3314e-01],
[ 3.8489e-02, -1.4918e-02, -8.7967e-02, ..., 4.6038e-01,
4.1668e-01, 3.7444e-01],
[-4.0742e-02, -1.4857e-02, 7.6456e-02, ..., -8.1258e-03,
-5.5882e-02, -2.3799e-02]],
[[ 8.9049e-02, 1.4258e-01, 3.8292e-01, ..., 1.6133e-01,
1.4335e-01, 1.8039e-01],
[-5.2064e-02, -1.0456e-01, -1.9761e-01, ..., -3.3952e-01,
-2.6607e-01, -1.7296e-01],
[-1.4956e-02, -9.2497e-02, -2.9282e-01, ..., -1.7801e-02,
-1.0827e-01, -8.1224e-02],
...,
[-2.5800e-02, 8.0382e-02, 6.1825e-01, ..., -3.8199e-01,
-3.4327e-01, -1.8147e-01],
[-7.3410e-03, 1.9852e-01, 9.0659e-01, ..., -4.1740e-01,
-3.9565e-01, -2.5667e-01],
[-3.5025e-02, -4.8974e-02, 2.1491e-01, ..., -1.5206e-01,
-1.6684e-01, -1.3646e-01]],
...,
[[ 1.0104e-02, -1.0961e-01, -5.9199e-02, ..., -6.9714e-02,
-7.6551e-02, 7.1423e-02],
[-1.3282e-01, -1.5589e-01, -3.8276e-02, ..., 4.1338e-01,
6.7665e-01, 3.6271e-01],
[ 1.4272e-01, 1.7284e-01, 1.2092e+00, ..., 9.3071e-01,
4.6994e-01, 2.0182e-01],
...,
[-3.4049e-02, -8.2063e-03, -2.9202e-01, ..., 9.4046e-01,
8.4497e-01, 9.4680e-01],
[ 2.6419e-02, -5.1124e-02, -3.0245e-01, ..., 1.2244e+00,
1.1257e+00, 8.6354e-01],
[-1.0748e-01, -1.5871e-01, -3.2379e-01, ..., 9.6839e-01,
6.4024e-01, 4.8128e-01]],
[[-1.4607e-01, -1.7049e-01, -1.4007e-01, ..., -1.1310e-01,
1.8378e-01, 1.6500e-01],
[-2.1871e-01, -2.9596e-01, -3.2113e-01, ..., 8.1939e-02,
3.1708e-02, -4.5414e-02],
[-1.0030e-01, -1.8369e-01, -1.4575e-01, ..., 1.8224e+00,
8.1758e-01, 6.1068e-01],
...,
[ 6.0960e-01, 7.4829e-01, 3.7702e-01, ..., -1.8992e-01,
1.4258e-01, 2.0653e-01],
[ 2.2604e-01, 4.1783e-01, 1.5157e-01, ..., 8.4110e-02,
2.3727e-01, 3.3342e-01],
[-3.8223e-02, 8.7143e-02, -2.1989e-02, ..., 3.9344e-01,
1.7073e-01, 2.8971e-01]],
[[ 7.9271e-02, 5.7074e-01, -2.1520e-02, ..., -1.1682e-01,
-1.7579e-01, -2.7374e-01],
[ 8.3261e-01, 1.0145e+00, 6.4379e-02, ..., -2.5812e-01,
-4.6593e-01, -1.3197e-01],
[ 7.9696e-01, 5.2228e-01, -1.6356e-01, ..., -9.7610e-02,
2.4922e-01, 7.9236e-01],
...,
[-2.9078e-01, -2.4341e-01, 6.1101e-01, ..., -1.4474e-01,
-5.7758e-01, -4.4565e-01],
[-2.4909e-01, -1.3604e-01, 2.8487e-01, ..., -1.3682e-01,
-4.3985e-01, -2.3561e-01],
[-5.0680e-02, 5.5990e-02, 2.8167e-01, ..., -1.4398e-01,
-1.8487e-01, -8.6340e-02]]],
[[[-2.9001e-02, -8.0461e-02, -2.2254e-03, ..., 3.4950e-01,
5.7094e-02, 1.2278e-01],
[ 6.5166e-03, 2.1499e-01, 3.9006e-01, ..., -4.7470e-02,
1.0600e-01, -1.4044e-02],
[-1.5046e-01, -1.8937e-01, -3.1187e-01, ..., -3.9665e-01,
-2.4484e-01, -9.6640e-02],
...,
[-7.1152e-03, 4.8723e-03, 2.0440e-01, ..., -4.2842e-01,
-3.7245e-01, -1.6724e-01],
[-7.7690e-02, -8.0645e-02, -3.2737e-02, ..., 7.0092e-01,
4.5857e-01, 4.7120e-02],
[-4.4350e-02, -6.3200e-02, -7.5735e-02, ..., -4.2281e-02,
-6.1911e-02, -5.9531e-02]],
[[-1.0874e-01, -1.8132e-01, -2.3345e-01, ..., -2.0851e-01,
-7.4565e-02, -4.1284e-02],
[-4.0450e-02, 8.0014e-02, 2.1230e-01, ..., 1.1867e+00,
8.4320e-01, 2.3790e-01],
[ 2.9826e-01, 5.6713e-01, 5.7998e-01, ..., -7.3589e-02,
-1.1340e-01, -7.4309e-02],
...,
[-5.2319e-02, -4.7811e-02, -4.5951e-02, ..., -2.6250e-01,
-2.1904e-01, -9.8821e-02],
[ 1.2224e-02, 1.4982e-01, 1.3104e-01, ..., 3.2171e-01,
1.0444e-01, 1.2559e-01],
[ 4.7305e-02, 9.6597e-02, 2.5040e-02, ..., -5.0359e-02,
-1.3494e-03, 7.0893e-02]],
[[ 1.6767e-01, 2.7971e-01, 4.6564e-01, ..., 2.8015e-01,
1.0774e-01, 1.1769e-01],
[-1.1139e-01, -2.2390e-01, -2.6545e-01, ..., -4.3858e-01,
-2.4102e-01, -4.8400e-02],
[ 5.1649e-03, -8.5191e-03, -1.2295e-01, ..., 8.0194e-01,
4.0112e-01, -5.4734e-02],
...,
[ 1.2798e-02, -5.9980e-02, 4.2741e-02, ..., 1.1245e+00,
5.3877e-01, -5.4408e-02],
[-1.2384e-02, -9.3497e-03, -1.2220e-02, ..., -9.9615e-02,
-2.4048e-02, 9.7492e-03],
[-2.8891e-02, 1.5380e-03, -7.4589e-03, ..., -3.1683e-03,
-3.5122e-02, -5.1562e-03]],
...,
[[ 1.0419e-01, -3.7381e-02, 1.1611e-01, ..., 3.1713e-01,
2.4542e-02, 2.7387e-01],
[-1.1054e-01, -1.0922e-01, -6.1246e-02, ..., -2.5663e-01,
-3.7523e-02, -8.6062e-02],
[-6.2047e-02, 1.1356e-01, 1.7014e-01, ..., 5.3426e-01,
2.4899e-01, -1.7628e-03],
...,
[-4.8685e-02, -1.0132e-01, -1.5885e-01, ..., 9.7035e-01,
4.5634e-01, 1.8380e-01],
[-5.9409e-02, -6.6526e-02, -8.1364e-02, ..., -1.7169e-01,
-1.5398e-01, -8.7368e-02],
[-1.3775e-02, 8.7838e-03, 7.3347e-02, ..., -3.3782e-02,
-2.5098e-02, -1.5676e-02]],
[[-2.5643e-02, -3.8952e-03, 1.1231e-01, ..., 4.4662e-01,
4.0375e-01, 9.5903e-02],
[-5.6202e-02, 2.6972e-02, 9.8674e-02, ..., 2.6287e-01,
5.7475e-01, 2.9372e-01],
[-2.0101e-03, -7.7510e-02, -5.1859e-02, ..., -2.2054e-01,
-1.6271e-01, 1.7489e-03],
...,
[-4.3067e-02, 1.4013e-01, 1.3372e-01, ..., -4.0701e-01,
-3.2545e-01, -8.4230e-02],
[ 2.6859e-02, 1.0184e-01, 2.8891e-01, ..., 5.7155e-01,
1.4783e-01, -1.1396e-02],
[ 1.0030e-01, 4.3159e-02, 5.3251e-02, ..., -7.9844e-02,
-1.5021e-02, -1.8097e-02]],
[[-1.4595e-02, 4.1319e-02, -1.0434e-01, ..., -2.9386e-01,
-2.9393e-01, -1.7882e-01],
[ 1.5605e-01, -9.8099e-02, -1.8437e-01, ..., -7.0900e-02,
-1.0099e-01, 2.3862e-02],
[-8.2990e-02, -1.8384e-02, 3.9302e-02, ..., -5.9204e-02,
2.0589e-01, -1.4739e-01],
...,
[ 9.2245e-03, 6.3241e-02, 1.8415e-01, ..., -1.0739e-01,
1.5667e-01, -2.0955e-01],
[-5.8324e-02, -5.9889e-02, -5.9691e-02, ..., 1.1105e-01,
-2.6855e-02, 2.6966e-01],
[-9.0968e-02, -1.1653e-01, -1.3540e-01, ..., 2.7048e-01,
7.0940e-02, -4.9028e-02]]],
[[[-7.2283e-02, 5.3242e-06, -1.2537e-01, ..., -3.6909e-02,
-4.0432e-02, -1.2870e-02],
[ 2.9738e-02, -9.0208e-02, -1.4939e-01, ..., -1.1575e-02,
-3.9611e-02, -1.2851e-01],
[-1.1368e-03, 8.5878e-02, 1.1499e+00, ..., -9.8711e-02,
-1.6185e-01, -3.1904e-02],
...,
[-3.7651e-02, 2.6948e-01, 6.2389e-01, ..., -4.9798e-01,
-5.0601e-01, -1.6457e-01],
[-3.7075e-02, -7.4441e-02, -6.7169e-02, ..., 1.1513e+00,
1.1047e+00, 4.2056e-01],
[-1.3934e-02, 2.3110e-01, 6.1425e-01, ..., -2.9988e-01,
-2.7497e-01, -1.5535e-01]],
[[ 6.4189e-01, 1.0841e+00, 1.3602e+00, ..., -1.9836e-01,
-1.2433e-01, -5.1696e-02],
[ 1.8029e-01, -3.3437e-02, -1.0972e-01, ..., 1.0138e+00,
5.3192e-01, 2.5186e-01],
[-2.0570e-01, -3.2785e-01, -3.4307e-01, ..., -3.4972e-01,
-1.7695e-01, -5.9314e-02],
...,
[ 4.6911e-01, 7.8272e-01, 1.2014e+00, ..., -6.5946e-01,
-4.1826e-01, -1.9988e-01],
[ 2.9877e-01, 1.8459e-01, 5.8291e-02, ..., 6.7500e-01,
6.4126e-01, 2.3666e-01],
[ 7.2424e-02, 1.4965e-01, 2.7101e-01, ..., -1.3135e-01,
-7.8532e-02, -5.5658e-02]],
[[-1.2613e-01, -1.8205e-01, -2.7218e-01, ..., 7.2364e-01,
2.9884e-01, -3.0641e-04],
[ 8.1657e-03, 2.2355e-01, 5.2882e-01, ..., -3.8305e-01,
-1.7700e-01, -9.3458e-02],
[ 3.8487e-01, 4.5329e-01, 1.0837e+00, ..., 7.5589e-01,
3.6727e-01, 2.6407e-01],
...,
[-7.2355e-02, -1.6600e-01, -7.3874e-02, ..., 1.5793e+00,
7.3415e-01, -2.0445e-02],
[ 1.2154e-02, 2.5559e-01, 5.1223e-01, ..., -2.2214e-01,
-1.3171e-01, 3.5690e-03],
[-3.8189e-02, -7.5509e-02, -1.2035e-02, ..., -7.7224e-03,
-2.7375e-02, -8.4515e-02]],
...,
[[-1.3251e-01, -1.8230e-02, -1.4040e-01, ..., 1.7689e-01,
-2.7634e-02, 2.1177e-02],
[ 3.9579e-01, 3.7851e-01, 2.2339e-01, ..., -1.4981e-01,
-3.4347e-02, -1.0266e-01],
[-3.3756e-04, -2.0567e-01, -2.4259e-01, ..., 6.6921e-01,
8.1090e-02, 3.9280e-01],
...,
[-1.8066e-01, -2.0184e-01, -4.9982e-01, ..., 2.2452e+00,
1.0433e+00, 9.4168e-01],
[ 2.3412e-01, 1.5095e-01, -2.0737e-02, ..., -3.7328e-01,
-1.7414e-01, -1.1398e-01],
[-7.4851e-02, -1.2918e-01, -2.4440e-01, ..., 7.8000e-01,
5.6403e-01, 2.9963e-01]],
[[ 1.6167e-01, -1.5616e-02, -8.3539e-02, ..., 1.6285e-01,
3.6807e-02, 1.2664e-01],
[ 1.7148e-01, 5.5283e-02, -8.5373e-02, ..., -8.5442e-02,
-1.0091e-01, -6.3013e-02],
[-4.1532e-02, 5.2311e-01, 9.4856e-01, ..., -1.0042e-01,
-1.3674e-02, -5.0829e-02],
...,
[-8.3947e-03, 1.7153e-01, 3.2893e-02, ..., -3.2762e-01,
-2.4920e-01, -6.0408e-02],
[ 3.8477e-01, 3.5678e-01, 2.6632e-01, ..., 1.3007e+00,
1.0828e+00, 4.5551e-01],
[ 1.4014e-01, 4.3396e-01, 4.7056e-01, ..., -2.0901e-01,
-1.3785e-01, -2.1573e-02]],
[[-9.5441e-02, -9.8721e-02, 4.1590e-01, ..., 5.8188e-02,
1.1586e-01, -7.8749e-02],
[-1.8588e-01, -1.4471e-01, -9.8197e-02, ..., 1.2607e-01,
1.0127e-01, 3.5296e-01],
[ 3.5947e-01, 4.9455e-01, 6.6820e-02, ..., -2.6978e-01,
-2.3982e-01, -2.5859e-01],
...,
[-8.1056e-03, 3.2226e-01, 1.3108e+00, ..., -5.7864e-01,
-3.4547e-01, -4.6089e-01],
[-1.4639e-01, -1.4384e-01, -7.4725e-02, ..., 3.4417e-01,
-6.9169e-02, 2.8609e-01],
[ 2.0396e-02, -2.0405e-02, 2.5683e-01, ..., -2.4720e-01,
-1.3365e-01, -1.8511e-01]]]], device='cuda:0')
And that’s it. We’re done with our CSI human reconstruction tutorial. Thanks for reading.
Total running time of the script: (0 minutes 50.698 seconds)