from PointTransformerV3.model import * | |
class PTV3(PointTransformerV3):
    """``PointTransformerV3`` subclass adding an encoder-only forward pass.

    The constructor is inherited unchanged from ``PointTransformerV3``. No
    ``__init__`` is defined here on purpose: a ``def __init__(self, **kwargs)``
    wrapper that only calls ``super().__init__(**kwargs)`` adds nothing and
    would reject any positional constructor arguments.
    """

    def encode(self, data_dict):
        """Embed and encode a point cloud, returning the encoder features.

        Only the ``embedding`` and ``enc`` stages are run. No other stage of
        the model (e.g. a decoder, if the parent defines one) is called here.

        Args:
            data_dict: Input handed directly to ``Point(...)``. Presumably a
                mapping with the keys ``Point`` expects (coordinates,
                features, grid size, batch/offset info) — TODO confirm
                against the parent model.

        Returns:
            ``point.feats`` of the encoded point cloud. Assumed to be a
            (num_points, channels) tensor after the encoder's pooling —
            NOTE(review): confirm shape against the parent model.
        """
        point = Point(data_dict)
        # Serialize using the ordering configuration stored on the model.
        # ``order`` and ``shuffle_orders`` are not set in this subclass, so
        # they come from the parent class.
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        point = self.embedding(point)
        point = self.enc(point)
        return point.feats