
CSI Human Reconstruction Tutorial

# !pip install pysensing

In this tutorial, we will walk through the code for the CSI human reconstruction task.

import sys
sys.path.append('../..')
import pysensing.csi.dataset.get_dataloader as get_dataloader
import pysensing.csi.model.get_model as get_model
import pysensing.csi.inference.predict as predict
import pysensing.csi.inference.train as train
import pysensing.csi.inference.embedding as embedding
import torch

Load the data

CSI reconstruction dataset:

# HandFi
# CSI size : 6, 20, 114
# image : 144, 144
# joints2d :  2, 42
# joints3d : 2, 21
# train number : 3600
# test number : 400

train_loader, test_loader = get_dataloader.load_recon_dataset('HandFi', batch_size=32, return_train=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for data in train_loader:
    ((joints, image), csi) = data
    joint = joints[:, :, 0:21].to(device, dtype=torch.float)
    img = image.to(device, dtype=torch.float)
    csi = csi.to(device, dtype=torch.float)
    joint2d = joint[:, 0:2, :]          # (x, y) coordinates of the 21 hand joints
    joint2d = joint2d.view(-1, 42)      # flatten to match the model's 2D joint output
    joint3d = joint[:, 2, :]            # per-joint depth

    print('data:', csi)
    print('img:', img)
    print('joint:', joint)
    break
using dataset: HandFi
data: tensor([[[[ -32.7930,  -38.9111,  -40.3795,  ...,  -48.4554,  -45.7634,
            -42.0926],
          [ -52.3264,  -56.3516,  -54.7415,  ...,  -70.0369,  -63.5968,
            -61.1817],
          [ -40.2269,  -48.7751,  -48.2723,  ...,   52.2950,   48.2723,
             45.7581],
          ...,
          [ -30.3462,  -28.4496,  -27.5012,  ...,  -71.1239,  -63.5373,
            -61.6407],
          [ -31.2848,  -35.6300,  -35.2823,  ...,   42.5821,   39.9751,
             36.4990],
          [ -27.3576,  -33.3301,  -32.1741,  ...,   40.6511,   36.6053,
             35.2567]],

         [[ -11.5020,  -12.7257,   -9.5442,  ...,  -11.5020,   -9.7890,
             -6.3628],
          [   8.0502,   12.0753,   16.9055,  ...,  -12.8804,   -6.4402,
            -10.4653],
          [ -29.1645,  -30.6730,  -28.6617,  ...,  -42.7411,  -40.7298,
            -38.2156],
          ...,
          [  42.6743,   50.2609,   55.0025,  ...,   -2.8450,    0.0000,
             -0.9483],
          [   5.7356,    7.4736,    7.1260,  ...,   -6.2570,   -6.6046,
             -3.4761],
          [ -14.2568,  -14.2568,  -15.9907,  ...,  -15.2201,  -16.1834,
            -15.4128]],

         [[  -8.8741,  -10.4875,   -5.1093,  ...,  -36.3030,  -30.9248,
            -26.6222],
          [  -6.2461,   -3.1231,   -1.0410,  ...,  -64.0227,  -51.5304,
            -54.1330],
          [ -52.8793,  -58.3875,  -55.6334,  ...,   -0.5508,   -1.6525,
             -3.8558],
          ...,
          [  43.7343,   49.8239,   59.7887,  ...,  -60.3423,  -49.8239,
            -45.9487],
          [   9.0769,   13.7834,   14.7920,  ...,   11.7663,   10.4216,
             13.1111],
          [ -15.6624,  -16.7303,  -13.1707,  ...,    4.2716,   -0.3560,
             -0.3560]],

         [[  48.4040,   50.5553,   55.3957,  ...,   47.5972,   45.9838,
             41.9501],
          [  71.3098,   77.5559,   75.9944,  ...,   52.5715,   52.0509,
             49.9689],
          [  40.2103,   47.3710,   49.0235,  ...,  -80.4206,  -77.1156,
            -74.3615],
          ...,
          [  55.3599,   54.2527,   55.3599,  ...,   60.8959,   58.1279,
             57.0207],
          [  41.6865,   47.0654,   46.0568,  ...,  -50.0910,  -50.4272,
            -45.0483],
          [  39.5120,   45.5633,   44.1395,  ...,  -53.7505,  -48.4111,
            -49.4789]],

         [[  14.7038,   17.5045,   16.8043,  ...,   59.9821,   52.2801,
             49.0126],
          [  23.4112,   23.4112,   28.0016,  ...,   85.8411,   79.8735,
             73.4469],
          [  19.3270,   23.3722,   27.4174,  ...,  -45.3961,  -40.4519,
            -35.9573],
          ...,
          [   8.1417,   10.5842,    8.1417,  ...,   85.8948,   79.3814,
             72.0539],
          [   9.0292,   12.8157,   12.2332,  ...,  -36.4082,  -34.6606,
            -32.9130],
          [  12.9031,   16.5897,   19.6618,  ...,  -28.8783,  -23.9628,
            -22.4268]],

         [[  -7.4686,   -6.3016,   -3.7343,  ...,  -21.9390,  -20.3052,
            -23.3393],
          [  -9.1809,  -12.3942,   -7.8037,  ...,   -8.2628,  -14.2303,
            -11.4761],
          [   8.0904,   11.2367,    8.5399,  ...,   71.9146,   70.5662,
             67.8694],
          ...,
          [ -24.4251,  -28.4959,  -26.8676,  ...,  -25.2392,  -25.2392,
            -25.2392],
          [ -10.4856,  -12.2332,  -15.4371,  ...,   35.8257,   36.1169,
             33.7868],
          [  -1.5361,   -2.1505,   -3.0722,  ...,   44.8535,   42.7030,
             40.5525]]],


        [[[ -42.5309,  -49.6194,  -49.0741,  ...,  -46.3478,  -44.7120,
            -43.0762],
          [  39.3843,   43.5416,   43.9792,  ...,   47.2612,   43.9792,
             39.3843],
          [  40.7355,   43.1317,   45.5279,  ...,   11.9810,   11.9810,
             10.7829],
          ...,
          [ -54.7987,  -63.9318,  -65.4540,  ...,    4.5666,   -1.5222,
              0.0000],
          [ -39.2957,  -43.6351,  -43.8761,  ...,   40.0189,   37.3670,
             35.1973],
          [  -5.2770,  -10.5540,  -14.6132,  ...,   45.4632,   43.8395,
             39.3744]],

         [[  11.4506,   15.8128,   17.9938,  ...,  -25.0823,  -20.7202,
            -17.9938],
          [ -10.9401,  -16.4101,  -19.6922,  ...,    1.0940,   -2.4068,
             -1.0940],
          [ -52.7165,  -62.3013,  -59.9051,  ...,   76.6786,   67.0937,
             63.4994],
          ...,
          [ -14.7145,  -10.1479,  -14.2071,  ...,  -67.4835,  -62.4096,
            -55.8135],
          [ -13.9825,  -16.6344,  -15.6700,  ...,  -25.5542,  -21.9381,
            -20.9738],
          [ -45.4632,  -49.9283,  -51.9579,  ...,  -25.5730,  -24.7612,
            -22.3257]],

         [[  -5.4406,   -4.0805,   -1.0881,  ...,  -61.4789,  -53.8621,
            -50.5977],
          [ -11.9511,  -17.1606,  -20.2249,  ...,   11.6447,    4.2901,
              5.2095],
          [ -30.0826,  -35.7585,  -38.0289,  ...,  105.5728,   96.4913,
             90.8153],
          ...,
          [ -33.7500,  -38.3523,  -36.3069,  ...,  -84.3751,  -81.3069,
            -76.1933],
          [ -12.1761,  -13.8365,  -15.4969,  ...,  -25.7359,  -25.7359,
            -25.7359],
          [ -44.9476,  -51.5253,  -54.2660,  ...,   -2.1926,   -4.6592,
             -7.3999]],

         [[  45.9732,   51.1418,   52.2299,  ...,   50.5977,   47.6054,
             44.3410],
          [ -38.9177,  -44.7400,  -45.3529,  ...,  -69.5616,  -62.5135,
            -59.7555],
          [ -58.4624,  -62.4356,  -65.2735,  ...,   28.9474,   24.4066,
             21.5686],
          ...,
          [  48.5796,   54.2046,   56.7614,  ...,  -48.0682,  -40.9091,
            -36.3069],
          [  40.4026,   45.9372,   44.5536,  ...,  -61.1574,  -60.6039,
            -57.0064],
          [ -10.1406,   -6.3036,   -7.1258,  ...,  -75.3694,  -72.3546,
            -68.2436]],

         [[  30.7796,   32.5726,   33.7679,  ...,   47.2153,   45.1235,
             39.4457],
          [ -26.0058,  -26.9458,  -25.0658,  ...,  -32.2722,  -25.6925,
            -25.0658],
          [ -31.4767,  -33.0247,  -39.2169,  ...,  -38.1849,  -35.0888,
            -32.5087],
          ...,
          [  31.5271,   42.0361,   47.8745,  ...,   21.6019,   22.1857,
             21.0181],
          [  28.4140,   31.6697,   34.3335,  ...,   -9.1753,  -11.5432,
             -8.5834],
          [  -1.8744,    1.5620,    4.0612,  ...,  -30.9273,  -26.2414,
            -20.3058]],

         [[  -2.6895,   -6.8731,   -3.8848,  ...,   -0.8965,   -1.7930,
             -3.2871],
          [  14.7262,   15.9795,   19.4260,  ...,   27.2591,   28.1991,
             22.5592],
          [  32.5087,   38.1849,   36.6368,  ...,  -54.1812,  -54.1812,
            -46.9571],
          ...,
          [  12.2605,   14.0120,   15.1797,  ...,   56.6320,   48.4583,
             45.5391],
          [   4.7357,    2.0719,    3.5517,  ...,   42.9169,   36.9973,
             34.0375],
          [  32.1769,   36.8629,   37.1753,  ...,   35.6133,   37.8001,
             31.8645]]],


        [[[  -3.9861,   -1.1389,    0.0000,  ...,  -51.2495,  -48.9718,
            -44.9857],
          [  25.5935,   26.0810,   25.3498,  ...,    9.2624,   10.4812,
              8.5312],
          [ -23.2937,  -30.8484,  -33.9962,  ...,   69.8810,   62.9559,
             61.0672],
          ...,
          [  48.3282,   53.0167,   55.1807,  ...,   33.1805,   32.8199,
             31.3773],
          [  28.6661,   34.0748,   37.8609,  ...,  -12.9809,  -10.8174,
             -9.7357],
          [ -27.7503,  -31.5518,  -36.4936,  ...,   53.9801,   50.5588,
             46.7574]],

         [[  48.4023,   53.5273,   55.2356,  ...,  -14.8054,  -14.8054,
            -14.2360],
          [ -34.3685,  -39.9747,  -43.3872,  ...,   47.5309,   44.1184,
             41.6809],
          [ -54.7716,  -61.0672,  -62.9559,  ...,   10.0729,    7.5547,
              8.8138],
          ...,
          [  11.1804,    7.9345,    8.2951,  ...,   45.8036,   37.5084,
             35.3445],
          [  38.9426,   40.0243,   41.1061,  ...,   54.6278,   49.2191,
             45.4330],
          [ -35.7333,  -41.8156,  -39.9149,  ...,   -2.6610,   -2.6610,
             -2.6610]],

         [[  65.3350,   68.0154,   71.7010,  ...,  -63.3247,  -52.9381,
            -46.5721],
          [ -52.9388,  -61.4945,  -66.3071,  ...,   80.7450,   67.9114,
             62.0293],
          [ -80.8251,  -80.1985,  -93.9826,  ...,   70.8002,   55.1365,
             53.2568],
          ...,
          [  27.1127,   22.5939,   22.0291,  ...,   84.1624,   79.0787,
             71.7357],
          [  57.4664,   61.8352,   57.4664,  ...,   62.1712,   62.5073,
             57.8024],
          [ -43.5171,  -34.2429,  -37.0965,  ...,    2.1402,   -2.8536,
             -7.1339]],

         [[  13.7371,   11.7268,    7.7062,  ...,   58.6340,   61.9845,
             53.6082],
          [ -21.9241,  -18.7157,  -20.3199,  ...,    3.7431,    1.0695,
             -3.2084],
          [  25.6886,   31.9541,   26.3151,  ...,  -85.8375,  -77.6923,
            -78.9454],
          ...,
          [ -57.6145,  -63.8278,  -75.1248,  ...,  -11.8618,  -19.2048,
            -15.2509],
          [ -29.2373,  -36.6306,  -38.3109,  ...,   56.7943,   50.4091,
             44.3600],
          [  41.3769,   52.0778,   52.7912,  ...,  -71.3395,  -70.6261,
            -63.4921]],

         [[  -1.5530,    0.0000,    0.9318,  ...,  -53.7336,  -50.3171,
            -46.5899],
          [   7.2149,    4.9758,    4.2294,  ...,   30.1034,   29.8547,
             27.6156],
          [ -24.4503,  -26.6334,  -29.2531,  ...,   67.6750,   60.6892,
             58.5062],
          ...,
          [  39.7358,   41.7351,   45.2339,  ...,   36.9868,   39.9857,
             36.4870],
          [  22.5564,   30.2811,   33.9890,  ...,   -2.1629,    0.0000,
              1.5450],
          [ -34.4755,  -37.2963,  -42.3109,  ...,   43.5645,   37.6097,
             33.2219]],

         [[  40.9991,   45.0369,   46.9005,  ...,   -7.7650,   -4.3484,
             -4.0378],
          [ -35.5768,  -41.5477,  -43.2893,  ...,   37.5671,   33.3377,
             31.8450],
          [ -45.8444,  -50.2105,  -51.0837,  ...,   -5.2394,   -7.4224,
             -6.1126],
          ...,
          [   7.7472,    4.7483,    3.9986,  ...,   34.9875,   33.2382,
             27.7401],
          [  31.5171,   34.9160,   34.2980,  ...,   52.8375,   50.6746,
             47.2757],
          [ -16.6109,  -18.1780,  -17.8646,  ...,  -29.1475,  -27.8938,
            -28.5207]]],


        ...,


        [[[ -48.2071,  -55.4837,  -56.3932,  ...,  -13.6435,  -13.6435,
            -12.7340],
          [ -52.5296,  -64.3911,  -65.5208,  ...,   68.3449,   65.5208,
             59.8724],
          [   7.8825,    8.3462,    3.9412,  ...,   30.8344,   30.3707,
             28.0523],
          ...,
          [  -4.0329,    2.4197,    4.8395,  ...,  -67.7527,  -61.3000,
            -58.8803],
          [   4.0634,    7.4495,   10.8356,  ...,  -46.0514,  -43.3425,
            -39.2791],
          [   2.1584,    4.3168,    8.6337,  ...,   -8.6337,   -9.7129,
             -9.7129]],

         [[ -48.2071,  -47.2976,  -49.1167,  ...,  -74.5846,  -69.1272,
            -64.5794],
          [ -20.3340,  -19.7692,  -22.0285,  ...,    7.3428,    4.5187,
              9.6022],
          [ -39.4124,  -44.7447,  -44.2810,  ...,   37.5577,   35.7030,
             34.0801],
          ...,
          [  56.4605,   62.9132,   64.5263,  ...,  -17.7447,  -14.5184,
            -13.7118],
          [  54.8553,   62.3048,   65.6910,  ...,  -54.1781,  -49.4375,
            -46.7286],
          [  59.3565,   64.7525,   65.8317,  ...,  -67.9902,  -64.7525,
            -60.4357]],

         [[ -31.8529,  -32.6492,  -31.8529,  ...,  -78.8359,  -78.8359,
            -68.4837],
          [  -2.5217,    0.8406,    1.6811,  ...,   -1.6811,   -5.0433,
             -0.8406],
          [ -45.3756,  -51.7747,  -49.1569,  ...,   36.6495,   34.3226,
             31.4139],
          ...,
          [  66.3063,   69.7961,   67.0043,  ...,  -12.5633,   -6.9796,
             -3.4898],
          [  57.7810,   66.5738,   70.3421,  ...,  -59.6652,  -50.2444,
            -43.3358],
          [  60.5831,   69.0365,   68.3321,  ...,  -74.6721,  -71.8543,
            -64.8098]],

         [[  70.0763,   75.6506,   79.6322,  ...,   27.0749,   23.0933,
             25.4823],
          [  63.8823,   72.2878,   73.1284,  ...,  -77.3312,  -70.6067,
            -71.4473],
          [   4.0722,    7.2717,    9.8896,  ...,  -41.8852,  -39.8491,
            -38.9765],
          ...,
          [ -21.6368,  -27.9185,  -25.8246,  ...,   76.7758,   71.8900,
             67.7023],
          [ -25.7502,  -30.7747,  -34.5430,  ...,   56.5249,   53.3846,
             50.2444],
          [ -21.1336,  -30.2915,  -30.2915,  ...,   21.1336,   22.5425,
             18.3158]],

         [[  23.7182,   32.5544,   34.4146,  ...,   24.6483,   24.1832,
             24.6483],
          [  33.4727,   42.2814,   39.9324,  ...,  -70.4689,  -62.2476,
            -62.8348],
          [   1.5144,    2.7763,    3.2811,  ...,  -45.1782,  -44.6735,
            -40.1304],
          ...,
          [   6.8301,    1.3660,    7.2855,  ...,   69.2119,   61.4711,
             57.8283],
          [   2.2891,    2.7469,    0.9156,  ...,   53.1067,   46.6973,
             42.1191],
          [   1.2312,    2.8728,    3.2831,  ...,   20.5197,   17.2365,
             13.5430]],

         [[  34.4146,   38.1351,   39.9954,  ...,   69.2943,   64.6437,
             57.2027],
          [  18.7917,   20.5534,   21.1407,  ...,    4.1107,    7.6341,
              2.3490],
          [  28.7727,   34.5778,   29.2775,  ...,  -20.9486,  -19.9390,
            -17.6675],
          ...,
          [ -41.8914,  -46.9001,  -45.5341,  ...,    7.7408,    3.1874,
              5.4641],
          [ -42.5769,  -44.4082,  -44.8660,  ...,   48.0707,   40.2878,
             38.4566],
          [ -40.6289,  -45.1432,  -46.3744,  ...,   66.8941,   61.5590,
             58.2758]]],


        [[[ -30.1405,  -36.7125,  -33.9930,  ...,  -16.0900,  -13.5972,
            -13.8238],
          [  31.3107,   28.8549,   29.4689,  ...,   58.3238,   53.4123,
             45.4312],
          [  12.5678,   16.8247,   15.0003,  ...,  -22.2978,  -18.6490,
            -17.0274],
          ...,
          [  -8.7890,  -10.8816,   -5.4408,  ...,  -63.1970,  -57.7562,
            -56.0821],
          [  33.4186,   36.1100,   37.2314,  ...,   57.6414,   53.3800,
             47.9971],
          [  31.1475,   35.2324,   29.8709,  ...,  -32.4240,  -29.6156,
            -28.0838]],

         [[ -21.7555,  -22.2088,  -23.1153,  ...,  -54.1622,  -49.6298,
            -46.2305],
          [ -42.3615,  -46.0451,  -54.0262,  ...,  -51.5705,  -52.7984,
            -49.1148],
          [  34.8656,   38.7170,   37.2981,  ...,   53.1092,   50.2713,
             46.6226],
          ...,
          [  45.6190,   51.0598,   55.6636,  ...,  -33.0633,  -30.9707,
            -25.1114],
          [ -15.0271,  -16.8214,  -23.9986,  ...,   -3.5886,   -4.7100,
             -5.3829],
          [ -26.5519,  -31.4028,  -32.1687,  ...,   54.3804,   52.5933,
             46.7212]],

         [[ -23.1707,  -21.9512,  -22.7642,  ...,  -65.0406,  -59.7560,
            -54.4715],
          [ -38.7183,  -45.0101,  -45.4940,  ...,    5.8077,   -0.4840,
             -1.4519],
          [  40.1074,   44.6479,   43.8911,  ...,   42.3776,   37.8372,
             38.2155],
          ...,
          [  50.8550,   59.7250,   60.3164,  ...,  -78.0565,  -72.7345,
            -66.8211],
          [ -24.7524,  -25.9501,  -30.7409,  ...,   29.9424,   24.3532,
             19.1631],
          [ -36.9102,  -40.5210,  -44.9342,  ...,   35.7066,   35.7066,
             31.2935]],

         [[  38.6178,   44.7154,   45.9349,  ...,  -13.4146,  -14.6341,
            -12.1951],
          [ -55.1736,  -54.2057,  -52.7537,  ...,  -89.5361,  -85.6643,
            -79.8565],
          [ -17.4051,  -24.2158,  -27.2428,  ...,   48.4316,   46.9181,
             44.6479],
          ...,
          [  26.0188,   27.2015,   26.0188,  ...,   21.2881,   23.6535,
             24.8362],
          [ -38.7255,  -43.1171,  -41.9194,  ...,  -57.8887,  -57.0902,
            -50.7025],
          [ -33.7007,  -34.5031,  -29.6887,  ...,   61.7845,   59.7785,
             57.7726]],

         [[ -16.6859,  -15.5482,  -20.0989,  ...,  -33.7509,  -31.0964,
            -32.9925],
          [   3.7433,    2.8075,   -1.8716,  ...,   78.1413,   63.1681,
             61.7644],
          [  17.1091,   17.8219,   17.4655,  ...,  -12.1189,   -3.9208,
             -5.7030],
          ...,
          [   5.6243,    6.8295,    8.0347,  ...,  -75.5260,  -69.9017,
            -65.0809],
          [   2.6111,    4.1032,    4.1032,  ...,   67.8896,   61.9213,
             56.3260],
          [  -3.4724,   -2.3149,   -0.3858,  ...,  -27.3935,  -22.7636,
            -20.0629]],

         [[   5.3091,    9.1014,    0.3792,  ...,  -61.8135,  -52.7121,
            -48.1615],
          [ -24.3314,  -23.8635,  -21.5240,  ...,  -59.4248,  -59.4248,
            -52.4062],
          [   2.4951,    4.9901,    6.4159,  ...,   71.2877,   65.5847,
             59.5252],
          ...,
          [  18.8815,   24.1040,   19.2832,  ...,  -45.7977,  -42.9855,
            -35.3526],
          [ -16.7859,  -17.5319,  -18.6510,  ...,  -22.3812,  -20.8891,
            -19.0240],
          [ -19.2912,  -20.0629,  -18.9054,  ...,   74.8499,   67.9051,
             62.1177]]],


        [[[  46.8680,   56.0261,   57.6423,  ...,  -66.8004,  -59.2584,
            -58.7197],
          [  36.8100,   37.9749,   41.0035,  ...,   57.7777,   50.7885,
             47.5268],
          [ -62.3675,  -76.4504,  -72.4267,  ...,   22.1304,   16.0948,
             18.1067],
          ...,
          [  14.1170,   22.5872,   22.5872,  ...,  -79.0551,  -73.4083,
            -64.9382],
          [  47.5145,   50.0928,   52.3027,  ...,   61.1427,   55.6177,
             52.3027],
          [  47.1487,   52.6040,   57.6695,  ...,  -58.4489,  -52.2143,
            -48.7074]],

         [[  46.3293,   51.7164,   51.7164,  ...,   44.7132,   42.0196,
             37.1712],
          [ -27.4910,  -33.7813,  -41.4695,  ...,    8.8530,    6.7563,
              7.6882],
          [ -44.2608,  -46.2726,  -48.2845,  ...,  -90.5334,  -82.4860,
            -72.4267],
          ...,
          [  79.0551,   93.1721,   95.9955,  ...,  -62.1148,  -59.2914,
            -53.6446],
          [ -25.7830,  -31.3080,  -33.8863,  ...,   19.5214,   17.3115,
             14.3648],
          [  23.7692,   25.3278,   27.2761,  ...,  -31.5624,  -28.0555,
            -25.7175]],

         [[  47.9657,   47.9657,   48.8707,  ...,   32.5804,   39.8205,
             35.2955],
          [ -33.4600,  -40.1521,  -45.3893,  ...,   32.2962,   25.6042,
             20.9489],
          [ -44.3696,  -45.0127,  -40.5114,  ..., -100.3139,  -99.6709,
            -88.0962],
          ...,
          [  89.7694,  102.5936,  100.9906,  ..., -100.9906,  -92.9754,
            -81.7543],
          [  -6.2048,  -12.7974,  -11.2462,  ...,   62.0480,   51.9652,
             49.2506],
          [  43.2276,   44.4510,   45.6744,  ...,  -70.9585,  -63.6180,
            -60.7633]],

         [[ -54.3007,  -66.9709,  -65.1609,  ...,   95.0263,   84.1661,
             88.6912],
          [ -37.8244,  -37.5334,  -39.5701,  ...,  -66.0472,  -59.6462,
            -55.8637],
          [  71.3772,   85.5240,   81.0228,  ...,  -50.8000,  -46.2987,
            -43.0835],
          ...,
          [ -22.4424,  -27.2514,  -32.0605,  ...,   81.7543,   78.5482,
             68.9301],
          [ -56.6188,  -59.7212,  -62.4358,  ...,  -53.5164,  -50.8018,
            -46.5360],
          [ -38.3339,  -40.7807,  -44.0432,  ...,   41.1886,   39.5573,
             35.8871]],

         [[  31.6533,   33.8362,   38.7480,  ...,  -48.0256,  -40.3852,
            -37.1107],
          [  11.5907,    9.2726,    8.6930,  ...,   49.8401,   44.0447,
             42.3061],
          [ -34.1257,  -44.3635,  -49.4823,  ...,  -14.2191,  -15.3566,
            -19.3379],
          ...,
          [  17.7250,   15.5980,   16.3070,  ...,  -98.5508,  -90.0428,
            -83.6618],
          [  26.2991,   26.2991,   29.8706,  ...,   65.5855,   58.4425,
             53.8970],
          [  29.6106,   31.0551,   30.6940,  ...,  -64.9990,  -56.6935,
            -55.6102]],

         [[  24.5586,   26.1958,   26.7415,  ...,   66.5810,   60.5778,
             56.2118],
          [ -26.3689,  -31.5847,  -31.0052,  ...,  -28.9768,  -26.0791,
            -23.7610],
          [ -18.7692,  -22.7505,  -26.7318,  ...,  -90.4332,  -83.0393,
            -73.9391],
          ...,
          [  46.7939,   53.1749,   53.8839,  ...,  -29.0689,  -23.3969,
            -27.6509],
          [ -20.1302,  -22.4030,  -19.8055,  ...,   -3.2468,   -5.1949,
             -2.5974],
          [  12.2776,   14.0831,   10.8332,  ...,  -11.1943,   -7.5832,
             -5.4166]]]], device='cuda:0')
img: tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')
joint: tensor([[[0.4065, 0.3842, 0.2791,  ..., 0.4321, 0.4197, 0.4292],
         [0.3706, 0.5283, 0.4555,  ..., 0.3087, 0.3569, 0.3985],
         [0.8637, 0.8774, 0.8543,  ..., 0.7560, 0.7454, 0.7595]],

        [[0.3400, 0.3053, 0.2496,  ..., 0.3744, 0.3568, 0.3587],
         [0.3861, 0.5380, 0.4331,  ..., 0.3324, 0.3786, 0.4183],
         [0.8486, 0.8448, 0.8010,  ..., 0.7469, 0.7428, 0.7609]],

        [[0.3652, 0.3266, 0.2775,  ..., 0.3989, 0.3775, 0.3696],
         [0.3761, 0.5271, 0.4217,  ..., 0.3347, 0.3765, 0.4194],
         [0.8042, 0.8070, 0.7568,  ..., 0.7023, 0.6862, 0.6876]],

        ...,

        [[0.3437, 0.3123, 0.2372,  ..., 0.4320, 0.4302, 0.4275],
         [0.4288, 0.5759, 0.4799,  ..., 0.3065, 0.2592, 0.2173],
         [0.8367, 0.8344, 0.8248,  ..., 0.8529, 0.8606, 0.8648]],

        [[0.3488, 0.3174, 0.2413,  ..., 0.4052, 0.3935, 0.3941],
         [0.3413, 0.4948, 0.3969,  ..., 0.3103, 0.3587, 0.3945],
         [0.7915, 0.8000, 0.7705,  ..., 0.6966, 0.6999, 0.7255]],

        [[0.3609, 0.3908, 0.2932,  ..., 0.3134, 0.2729, 0.2739],
         [0.3475, 0.4973, 0.4819,  ..., 0.2110, 0.2381, 0.2809],
         [0.8023, 0.7771, 0.8549,  ..., 0.7438, 0.7365, 0.7290]]],
       device='cuda:0')
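
Instead of inspecting full tensor dumps, it is usually enough to check the batch shapes. The snippet below is a minimal sketch reusing the variables from the loop above; the shapes in the comments follow the HandFi specification and may differ for other dataset versions.

print('csi:', tuple(csi.shape))          # e.g. (32, 6, 20, 114)
print('img:', tuple(img.shape))          # images, per the spec: 144 x 144
print('joint2d:', tuple(joint2d.shape))  # (32, 42) -> 21 joints x (x, y)
print('joint3d:', tuple(joint3d.shape))  # (32, 21) -> per-joint depth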

Load the model

Model zoo: AutoEncoder

model = get_model.load_recon_model('HandFi', 'AutoEncoder')
print(model)
CSI_AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(6, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
    (1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.3, inplace=True)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(32, 16, kernel_size=(4, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): ConvTranspose2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): Tanh()
  )
  (fc): Sequential(
    (0): Linear(in_features=13680, out_features=12544, bias=True)
    (1): ReLU()
  )
  (dropout): Dropout2d(p=0.5, inplace=False)
  (fj): Sequential(
    (0): Linear(in_features=28160, out_features=2048, bias=True)
    (1): ReLU()
  )
  (joint): Joints(
    (joints): Sequential(
      (0): Linear(in_features=2048, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Linear(in_features=1024, out_features=512, bias=True)
      (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Linear(in_features=512, out_features=64, bias=True)
      (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): LeakyReLU(negative_slope=0.01)
    )
    (linear1): Linear(in_features=64, out_features=42, bias=True)
    (linear2): Linear(in_features=64, out_features=21, bias=True)
  )
  (decmask): ResNet18Dec(
    (linear): Linear(in_features=2048, out_features=512, bias=True)
    (layer4): Sequential(
      (0): BasicBlockDec(
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlockDec(
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): ResizeConv2d(
          (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): ResizeConv2d(
            (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layer3): Sequential(
      (0): BasicBlockDec(
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlockDec(
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): ResizeConv2d(
          (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): ResizeConv2d(
            (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layer2): Sequential(
      (0): BasicBlockDec(
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlockDec(
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): ResizeConv2d(
          (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): ResizeConv2d(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layer1): Sequential(
      (0): BasicBlockDec(
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlockDec(
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1))
    (conv1): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1))
    (conv3): Conv2d(16, 4, kernel_size=(5, 5), stride=(1, 1))
    (conv4): Conv2d(4, 2, kernel_size=(5, 5), stride=(1, 1))
    (conv5): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
    (score_fr): Conv2d(4096, 2, kernel_size=(1, 1), stride=(1, 1))
    (upscore): ConvTranspose2d(64, 64, kernel_size=(8, 8), stride=(4, 4), bias=False)
  )
  (incep): InceptionV3(
    (Conv2d_1a_3x3): ConvBN(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (Conv2d_2a_3x3): ConvBN(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (Conv2d_2b_3x3): ConvBN(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (Mixed_5b): InceptionA(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch5x5): Sequential(
        (0): ConvBN(
          (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branch3x3): Sequential(
        (0): ConvBN(
          (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): ConvBN(
          (0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branchpool): Sequential(
        (0): AvgPool2d(kernel_size=1, stride=1, padding=0)
        (1): ConvBN(
          (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (Mixed_5c): InceptionA(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch5x5): Sequential(
        (0): ConvBN(
          (0): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branch3x3): Sequential(
        (0): ConvBN(
          (0): Conv2d(176, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): ConvBN(
          (0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branchpool): Sequential(
        (0): AvgPool2d(kernel_size=1, stride=1, padding=0)
        (1): ConvBN(
          (0): Conv2d(176, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (Mixed_5d): InceptionA(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch5x5): Sequential(
        (0): ConvBN(
          (0): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branch3x3): Sequential(
        (0): ConvBN(
          (0): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): ConvBN(
          (0): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): ConvBN(
          (0): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (branchpool): Sequential(
        (0): AvgPool2d(kernel_size=1, stride=1, padding=0)
        (1): ConvBN(
          (0): Conv2d(160, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
    )
    (Mixed_6a): Conv2d(152, 312, kernel_size=(5, 6), stride=(2, 2), padding=(1, 1))
    (Mixed_6b): InceptionC(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(312, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch7x7): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(312, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(128, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch7x7stack): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(312, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): BasicConv2d(
          (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): BasicConv2d(
          (conv): Conv2d(128, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch_pool): Sequential(
        (0): AvgPool2d(kernel_size=3, stride=1, padding=1)
        (1): BasicConv2d(
          (conv): Conv2d(312, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (Mixed_6c): InceptionC(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch7x7): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch7x7stack): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): BasicConv2d(
          (conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch_pool): Sequential(
        (0): AvgPool2d(kernel_size=3, stride=1, padding=1)
        (1): BasicConv2d(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (Mixed_6d): InceptionC(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (branch7x7): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch7x7stack): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(256, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): BasicConv2d(
          (conv): Conv2d(160, 160, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)
          (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): BasicConv2d(
          (conv): Conv2d(160, 64, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (branch_pool): Sequential(
        (0): AvgPool2d(kernel_size=3, stride=1, padding=1)
        (1): BasicConv2d(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
    (conv2): ConvBN(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
)
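
Before training, it can help to move the model onto the target device and check its size. This is a small optional sketch.

model = model.to(device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable parameters: {num_params / 1e6:.2f} M')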

Model training

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch_num = 1

train.recon_train(train_loader, model, epoch_num, optimizer, device)
Epoch:1, Loss:44.061765085
mPA: 0.949 | => IoU: 0.000 | => mpjpe: 0.439 | =>pck: 0.000
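
A single epoch is used here only to keep the tutorial fast. A longer run would look like the sketch below; the epoch count and checkpoint filename are illustrative, not values from the tutorial.

epoch_num = 50                                             # illustrative value
train.recon_train(train_loader, model, epoch_num, optimizer, device)
torch.save(model.state_dict(), 'handfi_autoencoder.pth')   # hypothetical checkpoint path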

Model inference

model = get_model.load_pretrain(model, 'HandFi', 'AutoEncoder', device=device)
output = predict.recon_predict(csi, 'HandFi', model, device)
_, mask, twod, threed = output
print("mask:", mask.shape)
print("twod:", twod.shape)
print("threed:", threed.shape)
mask: torch.Size([32, 1, 114, 114])
twod: torch.Size([32, 42])
threed: torch.Size([32, 21])
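
The 42-dimensional 2D output stores 21 joints with two coordinates each, mirroring how joint2d was flattened when loading the data, and the 21-dimensional 3D output holds one depth value per joint. A minimal sketch to recover per-joint coordinates from a prediction:

joints2d_pred = twod.view(-1, 2, 21)   # (batch, (x, y), 21 joints)
depth_pred = threed                    # (batch, 21) per-joint depth
print(joints2d_pred[0, :, 0])          # (x, y) of the first joint in the first sample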

Evaluate the model

IoUerr = train.IoU(img, mask)
mPAerr = train.mPA(img, mask)
mpjpe, pck = train.mpjpe_pck(joint2d, joint3d, twod, threed)

print(f'mPA: {mPAerr:.3f} | => IoU: {IoUerr:.3f} | => mpjpe: {mpjpe:.3f} | =>pck: {pck:.3f}\n')
mPA: 0.998 | => IoU: 0.139 | => mpjpe: 0.011 | =>pck: 1.000
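
The metrics above are computed on the single batch kept from the data-loading loop. A sketch for evaluating the whole test split with the same helpers is shown below; averaging per-batch metrics is an approximation of dataset-level scores.

model.eval()
totals = {'mPA': 0.0, 'IoU': 0.0, 'mpjpe': 0.0, 'pck': 0.0}
num_batches = 0
with torch.no_grad():
    for (joints, image), csi in test_loader:
        joint = joints[:, :, 0:21].to(device, dtype=torch.float)
        img = image.to(device, dtype=torch.float)
        csi = csi.to(device, dtype=torch.float)
        joint2d = joint[:, 0:2, :].view(-1, 42)
        joint3d = joint[:, 2, :]

        _, mask, twod, threed = predict.recon_predict(csi, 'HandFi', model, device)
        totals['mPA'] += train.mPA(img, mask)
        totals['IoU'] += train.IoU(img, mask)
        m, p = train.mpjpe_pck(joint2d, joint3d, twod, threed)
        totals['mpjpe'] += m
        totals['pck'] += p
        num_batches += 1

for name, value in totals.items():
    print(f'{name}: {value / num_batches:.3f}')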

Generate embeddings

csi_embedding = embedding.recon_csi_embedding(csi, 'HandFi', model, device)
print('csi_embedding: ', csi_embedding)
csi_embedding:  tensor([[[[-7.5988e-02, -1.3345e-01, -7.9132e-02,  ..., -1.2570e-01,
           -3.0211e-02,  8.9611e-02],
          [-7.8937e-02, -7.9026e-02,  1.0359e-02,  ..., -1.9818e-02,
            1.3651e-01,  2.5493e-01],
          [-6.4701e-02, -1.0918e-01, -8.3367e-02,  ...,  3.1941e-01,
           -1.4069e-02, -1.2549e-01],
          ...,
          [-1.9620e-02, -2.9558e-02,  9.6607e-02,  ..., -3.6337e-01,
           -2.3918e-01, -2.0665e-01],
          [-2.5418e-02,  6.4009e-02,  2.6116e-02,  ..., -3.9826e-03,
            4.3250e-02,  1.4635e-01],
          [-7.6564e-02, -1.0956e-01, -1.9091e-01,  ...,  2.4052e-01,
            1.4318e-01,  8.8810e-02]],

         [[-6.8691e-02, -1.3890e-01, -1.5969e-01,  ..., -1.3222e-01,
           -8.6994e-02, -7.1656e-03],
          [-9.5087e-02, -1.7732e-01, -1.2714e-01,  ..., -6.6031e-02,
           -1.4130e-02,  3.1650e-01],
          [-1.2589e-01, -2.6198e-01, -2.1262e-01,  ...,  9.4285e-01,
            7.7951e-01,  6.6286e-01],
          ...,
          [-1.1734e-01, -2.0700e-01, -1.7124e-01,  ...,  1.7257e-01,
           -2.0734e-02, -2.0480e-02],
          [-3.7680e-02, -4.1951e-02,  8.4990e-02,  ..., -1.9991e-02,
           -2.5701e-02, -6.3073e-02],
          [ 7.5083e-03, -2.4141e-02, -2.7933e-02,  ...,  1.0905e-01,
            4.7774e-02, -1.5171e-02]],

         [[ 9.8348e-02,  1.9889e-01,  3.1992e-01,  ...,  2.6857e-01,
            1.6400e-01,  4.3267e-02],
          [ 5.6098e-02,  1.7994e-01,  5.7166e-02,  ..., -9.1950e-02,
           -6.3024e-02, -6.6400e-02],
          [-5.9416e-03, -2.6637e-02, -1.1798e-01,  ..., -6.1100e-01,
           -4.4403e-01, -2.4352e-01],
          ...,
          [ 4.9742e-02, -7.8758e-03, -2.8914e-03,  ..., -1.6821e-01,
           -2.0014e-01, -6.8226e-02],
          [-3.1435e-02, -8.7355e-02, -1.8312e-01,  ...,  5.9714e-01,
            4.9378e-01,  2.9055e-01],
          [ 4.0038e-02, -2.1485e-03, -7.1473e-02,  ...,  9.5176e-02,
            1.7183e-01,  6.3368e-02]],

         ...,

         [[-1.9152e-03, -5.3906e-02, -3.0001e-02,  ..., -8.9673e-02,
           -7.1675e-02, -3.1333e-02],
          [-5.0727e-02, -8.1216e-02, -2.1506e-02,  ..., -3.3926e-02,
           -2.7612e-03,  1.0606e-01],
          [ 3.5804e-02, -2.7429e-02,  5.7590e-01,  ...,  1.0959e+00,
            1.0089e+00,  7.1672e-01],
          ...,
          [-3.4329e-02, -8.6428e-02, -6.2210e-02,  ...,  6.7720e-01,
            3.4152e-01,  2.8824e-01],
          [-5.4396e-02, -3.1631e-02,  2.2842e-01,  ..., -2.0054e-01,
           -1.6767e-01, -1.4687e-01],
          [ 6.9414e-02,  1.4295e-01,  4.5608e-01,  ..., -1.3391e-01,
           -9.4052e-02, -8.8097e-02]],

         [[-5.4808e-02, -7.3786e-02, -4.5198e-02,  ..., -1.1224e-01,
           -8.8538e-03,  1.2416e-01],
          [-1.0270e-01, -1.5658e-01, -1.4829e-01,  ..., -4.0651e-01,
           -1.6581e-01, -2.9898e-02],
          [-1.3807e-01, -2.1471e-01, -2.1912e-01,  ..., -1.7778e-01,
           -1.2842e-01, -1.0364e-01],
          ...,
          [-1.2237e-01, -1.4201e-01, -1.7165e-01,  ...,  1.0059e+00,
            6.8678e-01,  3.9695e-01],
          [-5.4163e-02, -7.8612e-02, -9.3758e-02,  ...,  2.8166e-01,
            1.9588e-01,  9.0473e-02],
          [ 3.5573e-02, -1.4971e-02,  6.5952e-02,  ...,  9.5610e-02,
            2.6193e-02, -9.3885e-04]],

         [[-1.2038e-02,  2.6270e-01, -1.5770e-02,  ...,  1.4886e-01,
           -1.1121e-03, -1.4652e-01],
          [ 3.5575e-01,  6.8593e-01,  8.8052e-02,  ...,  7.7226e-02,
           -1.5856e-01, -3.2138e-01],
          [ 6.0042e-01,  6.7811e-01, -7.3340e-02,  ..., -2.6675e-01,
           -5.4458e-01, -2.0304e-01],
          ...,
          [ 2.9060e-01,  6.5125e-01,  2.5959e-01,  ...,  2.2907e-01,
            1.8849e-02,  3.0123e-01],
          [ 3.3442e-01,  2.3328e-01, -3.7865e-02,  ..., -5.5345e-02,
            2.5826e-01,  2.3314e-02],
          [-6.4319e-03, -1.4516e-02, -4.3971e-02,  ..., -1.0781e-01,
           -5.1529e-02, -2.8931e-02]]],


        [[[-1.1201e-01, -1.6031e-01, -1.3094e-01,  ..., -1.1048e-01,
           -2.1807e-02,  1.9078e-01],
          [ 2.8870e-01,  5.8584e-01,  5.4395e-01,  ...,  6.9207e-01,
            2.9979e-01, -5.5536e-02],
          [-4.2719e-02, -7.1615e-02, -1.1386e-01,  ..., -3.8829e-01,
           -2.9365e-01, -2.7612e-01],
          ...,
          [-6.9176e-02, -9.9855e-02, -1.3018e-01,  ..., -1.8537e-01,
           -1.2654e-01, -1.1904e-01],
          ...  (remaining rows of the tensor output omitted for brevity)  ...]]]], device='cuda:0')
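
Rather than scanning the raw dump above, it is often more useful to summarise the tensor. Below is a minimal plain-PyTorch sketch (not a pysensing API call); the output variable and its shape are placeholders standing in for the tensor printed above:

import torch

# Illustrative placeholder standing in for the tensor printed above;
# substitute the real tensor and its actual shape.
output = torch.randn(32, 6, 20, 114)

print("shape:", tuple(output.shape))
print("device / dtype:", output.device, output.dtype)
print("value range:", output.min().item(), "to", output.max().item())
print("per-sample mean (first 5 samples):", output.flatten(1).mean(dim=1)[:5])

Summarising the batch this way keeps the tutorial output compact while still confirming the shape, device, and magnitude of the values.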

And that’s it. We’re done with our CSI human reconstruction tutorial. Thanks for reading.

Total running time of the script: (0 minutes 50.698 seconds)
