
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intraoral_scan/tooth_alignment.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_intraoral_scan_tooth_alignment.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intraoral_scan_tooth_alignment.py:


Tutorial for tooth alignment
============================

.. GENERATED FROM PYTHON SOURCE LINES 6-30

.. code-block:: Python


    import os
    import sys
    import torch
    import requests

    from pysensing.intraoral_scan.preprocessing.ta_utils import *
    from pysensing.intraoral_scan.inference.utils import ta_dataloader
    from pysensing.intraoral_scan.inference.ta_predict import predict


    def download_weights(remote_url, local_path):
        """Download pretrained weights to local_path unless the file already exists."""
        if not os.path.exists(local_path):
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            print(f"Downloading weights from {remote_url}...")
            # Stream the response so the checkpoint is written in chunks
            # instead of being held in memory all at once.
            response = requests.get(remote_url, stream=True)
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print("Download complete.")
        else:
            print("Weights already exist. Skipping download.")


.. GENERATED FROM PYTHON SOURCE LINES 31-34

Load Model (pick one of the following three models)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 34-66

.. code-block:: Python


    # Run only one of the three sections below. If all three run in sequence,
    # each assignment overwrites the previous one and TANet is the model used later.

    # Load CurveNet
    from pysensing.intraoral_scan.models.tooth_alignment.curvenet import CurveNet

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model = CurveNet()
    REMOTE_WEIGHT_URL = "https://pysensing.oss-ap-southeast-1.aliyuncs.com/pretrain/intraoral_scan/tooth_alignment/CurveNet.pth"
    LOCAL_WEIGHT_PATH = "models/CurveNet"
    download_weights(REMOTE_WEIGHT_URL, LOCAL_WEIGHT_PATH)
    model.load_state_dict(torch.load(LOCAL_WEIGHT_PATH, weights_only=True, map_location="cuda"))

    # Load DGCNN
    from pysensing.intraoral_scan.models.tooth_alignment.dgcnn import DGCNN

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model = DGCNN()
    REMOTE_WEIGHT_URL = "https://pysensing.oss-ap-southeast-1.aliyuncs.com/pretrain/intraoral_scan/tooth_alignment/DGCNN.pth"
    LOCAL_WEIGHT_PATH = "models/DGCNN_TA"
    download_weights(REMOTE_WEIGHT_URL, LOCAL_WEIGHT_PATH)
    model.load_state_dict(torch.load(LOCAL_WEIGHT_PATH, weights_only=True, map_location="cuda"))

    # Load TANet
    from pysensing.intraoral_scan.models.tooth_alignment.tanet import TANet

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model = TANet()
    REMOTE_WEIGHT_URL = "https://pysensing.oss-ap-southeast-1.aliyuncs.com/pretrain/intraoral_scan/tooth_alignment/TANet.pth"
    LOCAL_WEIGHT_PATH = "models/TANet"
    download_weights(REMOTE_WEIGHT_URL, LOCAL_WEIGHT_PATH)
    model.load_state_dict(torch.load(LOCAL_WEIGHT_PATH, weights_only=True, map_location="cuda"))


.. GENERATED FROM PYTHON SOURCE LINES 67-70

Load Dataset
~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 70-78

.. code-block:: Python


    dataset_path = "../datasets/tooth_alignment/example/data"
    batch_size = 3

    data_loader = ta_dataloader.DataLoader(dataset_path, batch_size)
    print(len(data_loader))


.. GENERATED FROM PYTHON SOURCE LINES 79-82

Model Inference
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 82-92

.. code-block:: Python


    # Predict the results
    num_point_tooth = 400  # number of points in each tooth point cloud
    preds = predict(data_loader, num_point_tooth, model, "cuda")

    # Transform the predictions into tooth poses
    root = "../datasets/tooth_alignment/example"
    ans_pose = trans_pred(data_loader, preds, root)


.. GENERATED FROM PYTHON SOURCE LINES 93-96

Visualization of Tooth Alignment Results
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 96-110

.. code-block:: Python


    data_idx = 0  # index of the sample to visualize
    (_, dirs) = get_idx_dirs(root)
    objs = getTooth(data_idx, root, dirs)

    # Show the original dentition
    pose = getAxis(f'{root}/{dirs[data_idx]}/TeethAxis_Ori.txt', keep_fdi=True)
    showTooth(objs, pose).show()

    # Show the aligned (predicted) dentition
    showTooth(objs, ans_pose[data_idx]).show()

    # Show the ground-truth (GT) dentition
    pose = getAxis(f'{root}/{dirs[data_idx]}/TeethAxis_T2.txt', keep_fdi=True)
    showTooth(objs, pose).show()


.. _sphx_glr_download_intraoral_scan_tooth_alignment.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tooth_alignment.ipynb <tooth_alignment.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tooth_alignment.py <tooth_alignment.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
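

Referring back to the model-loading step: the tutorial asks you to pick one of CurveNet, DGCNN, or TANet, but the listing runs all three in sequence. The following is a minimal sketch of choosing a single checkpoint by name. ``select_checkpoint``, ``CHECKPOINTS``, and ``BASE_URL`` are hypothetical names introduced here for illustration and are not part of pysensing; the URLs and local paths are copied from the loading step.

.. code-block:: Python

    # Hypothetical helper, not part of pysensing: maps a model name to the
    # checkpoint URL and local path used in the model-loading step.
    BASE_URL = (
        "https://pysensing.oss-ap-southeast-1.aliyuncs.com/"
        "pretrain/intraoral_scan/tooth_alignment"
    )

    # Remote file name and local path per model, as given in the tutorial.
    CHECKPOINTS = {
        "CurveNet": ("CurveNet.pth", "models/CurveNet"),
        "DGCNN": ("DGCNN.pth", "models/DGCNN_TA"),
        "TANet": ("TANet.pth", "models/TANet"),
    }


    def select_checkpoint(name):
        """Return (remote_url, local_path) for one of the three model names."""
        if name not in CHECKPOINTS:
            raise ValueError(
                f"Unknown model {name!r}; choose from {sorted(CHECKPOINTS)}"
            )
        remote_file, local_path = CHECKPOINTS[name]
        return f"{BASE_URL}/{remote_file}", local_path


    remote_url, local_path = select_checkpoint("TANet")
    print(remote_url)
    print(local_path)

The returned pair can be passed to ``download_weights`` and ``torch.load`` exactly as in the loading step, so only the chosen model is downloaded and instantiated.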