.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "csi/Reconstruction_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_csi_Reconstruction_tutorial.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_csi_Reconstruction_tutorial.py:


CSI Reconstruction Tutorial
==============================================================

.. GENERATED FROM PYTHON SOURCE LINES 7-10

.. code-block:: Python

    # !pip install pysensing

.. GENERATED FROM PYTHON SOURCE LINES 11-13

This tutorial walks through the code for the CSI reconstruction task.

.. GENERATED FROM PYTHON SOURCE LINES 13-22

.. code-block:: Python

    import sys
    sys.path.append('../..')
    import pysensing.csi.dataset.get_dataloader as get_dataloader
    import pysensing.csi.model.get_model as get_model
    import pysensing.csi.inference.predict as predict
    import pysensing.csi.inference.train as train
    import pysensing.csi.inference.embedding as embedding
    import torch

.. GENERATED FROM PYTHON SOURCE LINES 23-26

Load the data
-----------------------------------
CSI reconstruction dataset:

.. GENERATED FROM PYTHON SOURCE LINES 26-52

.. code-block:: Python

    # HandFi
    # CSI size : 6, 20, 114
    # image : 144, 144
    # joints2d : 2, 42
    # joints3d : 2, 21
    # train number : 3600
    # test number : 400

    train_loader, test_loader = get_dataloader.load_recon_dataset('HandFi', batch_size=32, return_train=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Take one batch and unpack it into joints, image and CSI tensors.
    for data in train_loader:
        ((joints, image), csi) = data
        # Keep the first 21 joints and move everything to the target device.
        joint = joints[:, :, 0:21].to(device, dtype=torch.float)
        img = image.to(device, dtype=torch.float)
        csi = csi.to(device, dtype=torch.float)
        # Rows 0-1 are flattened to 42 values per sample; row 2 is kept separately.
        joint2d = joint[:, 0:2, :]
        joint2d = joint2d.view(-1, 42)
        joint3d = joint[:, 2, :]

        print('data:', csi)
        print('img:', img)
        print('joint:', joint)
        break

.. rst-class:: sphx-glr-script-out

.. code-block:: pytb

    Traceback (most recent call last):
      File "/data1/msc/zyj/yunjiao_csi/1028/yunjiao_csi/tutorials/csi_source/Reconstruction_tutorial.py", line 36, in <module>
        train_loader, test_loader = get_dataloader.load_recon_dataset('HandFi', batch_size=32, return_train=True)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/data1/msc/zyj/yunjiao_csi/1028/yunjiao_csi/tutorials/csi_source/../../pysensing/csi/dataset/get_dataloader.py", line 200, in load_recon_dataset
        raise ValueError("Download failed.")
    ValueError: Download failed.

.. GENERATED FROM PYTHON SOURCE LINES 53-56

Load the model
-----------------------------------
For the HandFi dataset, the model zoo contains AutoEncoder.

.. GENERATED FROM PYTHON SOURCE LINES 56-60

.. code-block:: Python

    model = get_model.load_recon_model('HandFi', 'AutoEncoder')
    print(model)

.. GENERATED FROM PYTHON SOURCE LINES 61-63

Model train
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 63-68

.. code-block:: Python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    epoch_num = 1

    train.recon_train(train_loader, model, epoch_num, optimizer, device)

.. GENERATED FROM PYTHON SOURCE LINES 69-71

Model inference
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 71-80

.. code-block:: Python

    # Load pretrained weights, then run inference on the CSI batch loaded above.
    model = get_model.load_pretrain(model, 'HandFi', 'AutoEncoder', device=device)
    output = predict.recon_predict(csi, 'HandFi', model, device)
    _, mask, twod, threed = output
    print("mask:", mask.shape)
    print("twod:", twod.shape)
    print("threed:", threed.shape)

.. GENERATED FROM PYTHON SOURCE LINES 81-83

Evaluate the loss
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 83-89

.. code-block:: Python

    # Mask metrics compare the image with the predicted mask;
    # joint metrics compare the 2D/3D targets with the predictions.
    IoUerr = train.IoU(img, mask)
    mPAerr = train.mPA(img, mask)
    mpjpe, pck = train.mpjpe_pck(joint2d, joint3d, twod, threed)
    print(f'mPA: {mPAerr:.3f} | => IoU: {IoUerr:.3f} | => mpjpe: {mpjpe:.3f} | =>pck: {pck:.3f}\n')

.. GENERATED FROM PYTHON SOURCE LINES 90-92

Generate embedding
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 92-97

.. code-block:: Python

    csi_embedding = embedding.recon_csi_embedding(csi, 'HandFi', model, device)
    print('csi_embedding: ', csi_embedding)

.. GENERATED FROM PYTHON SOURCE LINES 98-99

And that's it. We're done with our CSI reconstruction tutorial. Thanks for reading.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.592 seconds)


.. _sphx_glr_download_csi_Reconstruction_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: Reconstruction_tutorial.ipynb <Reconstruction_tutorial.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: Reconstruction_tutorial.py <Reconstruction_tutorial.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
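
The evaluation step of this tutorial prints IoU, MPJPE and PCK values. As a rough illustration of what metrics with these names conventionally measure, the pure-Python sketch below computes them on tiny hand-made inputs. It is not the implementation behind ``train.IoU`` or ``train.mpjpe_pck``: the function names ``mask_iou``, ``mean_joint_error`` and ``pck_score``, the 0.5 mask threshold and the tolerance value are all assumptions made for this example.

.. code-block:: Python

    import math

    def mask_iou(pred, target, threshold=0.5):
        """Intersection over union of two masks given as nested lists.

        The 0.5 threshold is an assumption for this sketch.
        """
        inter = union = 0
        for pred_row, target_row in zip(pred, target):
            for p, t in zip(pred_row, target_row):
                p_on, t_on = p >= threshold, t >= threshold
                inter += p_on and t_on
                union += p_on or t_on
        # Two empty masks are treated as a perfect match.
        return inter / union if union else 1.0

    def mean_joint_error(pred, target):
        """Mean Euclidean distance between matching joints (MPJPE)."""
        return sum(math.dist(p, t) for p, t in zip(pred, target)) / len(pred)

    def pck_score(pred, target, tolerance):
        """Fraction of joints whose error is at most `tolerance` (PCK)."""
        hits = sum(math.dist(p, t) <= tolerance for p, t in zip(pred, target))
        return hits / len(pred)

    # Hypothetical toy data: a 2x2 mask and two 2D joints.
    pred_mask = [[0.9, 0.2], [0.8, 0.7]]
    true_mask = [[1, 0], [0, 1]]
    pred_joints = [(0.0, 0.0), (3.0, 4.0)]
    true_joints = [(0.0, 0.0), (0.0, 0.0)]

    print(f"IoU: {mask_iou(pred_mask, true_mask):.3f}")           # IoU: 0.667
    print(f"MPJPE: {mean_joint_error(pred_joints, true_joints)}")  # MPJPE: 2.5
    print(f"PCK: {pck_score(pred_joints, true_joints, 1.0)}")      # PCK: 0.5

In the toy data the second joint is off by 5 units, so the mean error is 2.5 and only half of the joints fall within a tolerance of 1.0. Lower MPJPE and higher IoU and PCK indicate a better reconstruction.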