发布时间:2024-05-24 17:01
Apollo中的KD-Tree实现代码.
KD-Tree参数:
struct AABoxKDTreeParams {
/// The maximum depth of the kdtree.
int max_depth = -1;
/// The maximum number of items in one leaf node.
int max_leaf_size = -1;
/// The maximum dimension size of leaf node.
// 可以看做是拆分到最后时像素间的最大距离
double max_leaf_dimension = -1.0;
};
对外接口,模板类:
/**
* @class AABoxKDTree2d
* @brief The class of KD-tree of Aligned Axis Bounding Box(AABox).
*/
template <class ObjectType>
class AABoxKDTree2d {
public:
using ObjectPtr = const ObjectType *;
/**
* 使用 对象 构成的vector + KD-Tree参数 进行构造 (建树)
* @brief Constructor which takes a vector of objects and parameters.
* @param params Parameters to build the KD-tree.
*/
AABoxKDTree2d(const std::vector<ObjectType> &objects,
const AABoxKDTreeParams ¶ms) {
if (!objects.empty()) {
std::vector<ObjectPtr> object_ptrs;
for (const auto &object : objects) {
object_ptrs.push_back(&object);
}
// 具体建树接口
root_.reset(new AABoxKDTree2dNode<ObjectType>(object_ptrs, params, 0));
}
}
/**
* @brief Get the nearest object to a target point. 最近邻查找
* @param point The target point. Search it\'s nearest object.
* @return The nearest object to the target point.
*/
ObjectPtr GetNearestObject(const Vec2d &point) const {
return root_ == nullptr ? nullptr : root_->GetNearestObject(point);
}
/**
* @brief Get objects within a distance to a point. 根据距离查找
* @param point The center point of the range to search objects.
* @param distance The radius of the range to search objects.
* @return All objects within the specified distance to the specified point.
*/
std::vector<ObjectPtr> GetObjects(const Vec2d &point,
const double distance) const {
if (root_ == nullptr) {
return {};
}
return root_->GetObjects(point, distance);
}
/**
* @brief Get the axis-aligned bounding box of the objects.
* @return The axis-aligned bounding box of the objects.
*/
AABox2d GetBoundingBox() const {
return root_ == nullptr ? AABox2d() : root_->GetBoundingBox();
}
private:
std::unique_ptr<AABoxKDTree2dNode<ObjectType>> root_ = nullptr;
};
/**
* @brief Constructor which takes a vector of objects,
* parameters and depth of the node.
* @param objects Objects to build the KD-tree node.
* @param params Parameters to build the KD-tree.
* @param depth Depth of the KD-tree node.
*/
AABoxKDTree2dNode(const std::vector<ObjectPtr> &objects,
const AABoxKDTreeParams ¶ms, int depth)
: depth_(depth) {
CHECK(!objects.empty());
//Step 1.计算 Dx 边界
ComputeBoundary(objects);
//Step 2.对当前Node计算划分点 partition_position_
ComputePartition();
//Step 3.分裂当前Node为左右子树
if (SplitToSubNodes(objects, params)) {//检查是否分裂是否满足参数条件
std::vector<ObjectPtr> left_subnode_objects;
std::vector<ObjectPtr> right_subnode_objects;
//Step 3.1.分裂Node
PartitionObjects(objects, &left_subnode_objects, &right_subnode_objects);
//Step 3.2. Split to sub-nodes. 分裂到左右子树
if (!left_subnode_objects.empty()) {
left_subnode_.reset(new AABoxKDTree2dNode<ObjectType>(//在构造函数中递归
left_subnode_objects, params, depth + 1));
}
if (!right_subnode_objects.empty()) {
right_subnode_.reset(new AABoxKDTree2dNode<ObjectType>(
right_subnode_objects, params, depth + 1));
}
} else {
InitObjects(objects);//?这句会运行吗
}
}
void ComputeBoundary(const std::vector<ObjectPtr> &objects) {
//初始化值为正无穷大
min_x_ = std::numeric_limits<double>::infinity();
min_y_ = std::numeric_limits<double>::infinity();
max_x_ = -std::numeric_limits<double>::infinity();
max_y_ = -std::numeric_limits<double>::infinity();
for (ObjectPtr object : objects) {
min_x_ = std::fmin(min_x_, object->aabox().min_x());
max_x_ = std::fmax(max_x_, object->aabox().max_x());
min_y_ = std::fmin(min_y_, object->aabox().min_y());
max_y_ = std::fmax(max_y_, object->aabox().max_y());
}
mid_x_ = (min_x_ + max_x_) / 2.0;
mid_y_ = (min_y_ + max_y_) / 2.0;
CHECK(!std::isinf(max_x_) && !std::isinf(max_y_) && !std::isinf(min_x_) &&
!std::isinf(min_y_))
<< \"the provided object box size is infinity\";
}
KD-Tree划分计算:长和宽哪个更大就按照哪个方向进行划分,划分的点也比较粗暴,直接采用中点进行划分,因此这里和笔记中的选择划分点有区别,笔记中或一般 kd-tree 的节点(划分点)选取依然是输入数据中的点,而这里不是,这里直接以物理空间进行“均分”。
void ComputePartition() {
if (max_x_ - min_x_ >= max_y_ - min_y_) {
partition_ = PARTITION_X;
partition_position_ = (min_x_ + max_x_) / 2.0;
} else {
partition_ = PARTITION_Y;
partition_position_ = (min_y_ + max_y_) / 2.0;
}
}
在 step 3.1中构建了当前节点的数据,于是看可以看出建树的过程是一个前序遍历建树的过程。
if (SplitToSubNodes(objects, params)) {//检查是否分裂是否满足参数条件
std::vector<ObjectPtr> left_subnode_objects;
std::vector<ObjectPtr> right_subnode_objects;
//Step 3.1.划分空间
PartitionObjects(objects, &left_subnode_objects, &right_subnode_objects);
//Step 3.2. Split to sub-nodes. 分裂到左右子树
if (!left_subnode_objects.empty()) {
// 在构造函数中递归
left_subnode_.reset(new AABoxKDTree2dNode<ObjectType>(
left_subnode_objects, params, depth + 1));
}
if (!right_subnode_objects.empty()) {
right_subnode_.reset(new AABoxKDTree2dNode<ObjectType>(
right_subnode_objects, params, depth + 1));
}
}
首先检测是否需要继续分裂:
bool SplitToSubNodes(const std::vector<ObjectPtr> &objects,
const AABoxKDTreeParams ¶ms) {
if (params.max_depth >= 0 && depth_ >= params.max_depth) {
return false;
}
if (static_cast<int>(objects.size()) <= std::max(1, params.max_leaf_size)) {
return false;
}
// 这里的维度距离就可以看作是像素之间的距离
if (params.max_leaf_dimension >= 0.0 &&
std::max(max_x_ - min_x_, max_y_ - min_y_) <=
params.max_leaf_dimension) {
return false;
}
return true;
}
开始划分左右子空间:
//将当前Node分裂为左右子树
void PartitionObjects(const std::vector<ObjectPtr> &objects,
std::vector<ObjectPtr> *const left_subnode_objects,
std::vector<ObjectPtr> *const right_subnode_objects) {
left_subnode_objects->clear();
right_subnode_objects->clear();
std::vector<ObjectPtr> other_objects;
if (partition_ == PARTITION_X) {
for (ObjectPtr object : objects) {
if (object->aabox().max_x() <= partition_position_) {
left_subnode_objects->push_back(object);
} else if (object->aabox().min_x() >= partition_position_) {
right_subnode_objects->push_back(object);
} else {
other_objects.push_back(object);//在分界线上
}
}
} else {
for (ObjectPtr object : objects) {
if (object->aabox().max_y() <= partition_position_) {
left_subnode_objects->push_back(object);
} else if (object->aabox().min_y() >= partition_position_) {
right_subnode_objects->push_back(object);
} else {
other_objects.push_back(object);
}
}
}
// 保存当前节点数据
InitObjects(other_objects);
}
//$ 对当前节点数据进行处理保存
void InitObjects(const std::vector<ObjectPtr> &objects) {
num_objects_ = static_cast<int>(objects.size()); //当前节点个数?
objects_sorted_by_min_ = objects;
objects_sorted_by_max_ = objects;
std::sort(objects_sorted_by_min_.begin(), objects_sorted_by_min_.end(),
[&](ObjectPtr obj1, ObjectPtr obj2) {
return partition_ == PARTITION_X
? obj1->aabox().min_x() < obj2->aabox().min_x()
: obj1->aabox().min_y() < obj2->aabox().min_y();
});
std::sort(objects_sorted_by_max_.begin(), objects_sorted_by_max_.end(),
[&](ObjectPtr obj1, ObjectPtr obj2) {
return partition_ == PARTITION_X
? obj1->aabox().max_x() > obj2->aabox().max_x()
: obj1->aabox().max_y() > obj2->aabox().max_y();
});
objects_sorted_by_min_bound_.reserve(
num_objects_); //记录从小到大排序的坐标x值
for (ObjectPtr object : objects_sorted_by_min_) {
objects_sorted_by_min_bound_.push_back(partition_ == PARTITION_X
? object->aabox().min_x()
: object->aabox().min_y());
}
objects_sorted_by_max_bound_.reserve(num_objects_);
for (ObjectPtr object : objects_sorted_by_max_) {
objects_sorted_by_max_bound_.push_back(partition_ == PARTITION_X
? object->aabox().max_x()
: object->aabox().max_y());
}
}
构建树的目的就是降低查找的时间复杂度。可以查找最近邻,也可以查找k近邻。
(1)最近邻查找
1.首先要找到该目标点的叶子节点进行判断,记录最近距离与nearest_object;
2.回溯判断另一半空间是否需要计算,若需要在另一个子树继续搜索最近邻,更新记录值。
3.当回溯到根节点时,算法结束,此时保存的最近邻节点就是最终的最近邻。
参考阅读 详细看查询距离。
void GetNearestObjectInternal(
const Vec2d &point,
double *const min_distance_sqr, //指针不变,但指针指向的内容可以变
ObjectPtr *const nearest_object) const {
// Step 4.距离大于当前最小距离的子树就不计算了,直接返回(回溯)-->剪枝操作
if (LowerDistanceSquareToPoint(point) >= *min_distance_sqr - kMathEpsilon) {
return;
}
// !后序访问:注意这里不是遍历访问
// Step 1.DFS:将按point搜索,改为按当前节点划分方式的值进行搜索
const double pvalue = (partition_ == PARTITION_X ? point.x() : point.y());
const bool search_left_first = (pvalue < partition_position_);
// 这里的 DFS 不是遍历访问,是 if else 关系,访问另一子空间是在回溯时访问
if (search_left_first) {
if (left_subnode_ != nullptr) {
left_subnode_->GetNearestObjectInternal(point, min_distance_sqr,
nearest_object); //递归查找
}
} else {
if (right_subnode_ != nullptr) {
right_subnode_->GetNearestObjectInternal(point, min_distance_sqr,
nearest_object);
}
}
// 距离 == 0 , 找到最近邻,直接返回,不用深入了
if (*min_distance_sqr <= kMathEpsilon) { //回溯的时候min_distance_sqr值会改变
return;
}
// Step 2.欧氏距离找最近的object
if (search_left_first) {
for (int i = 0; i < num_objects_; ++i) {
const double bound = objects_sorted_by_min_bound_[i];
// ?框的最小边界都比它大,说明在框外面的意思?
if (bound > pvalue && Square(bound - pvalue) > *min_distance_sqr) {
break;
}
ObjectPtr object = objects_sorted_by_min_[i];
const double distance_sqr = object->DistanceSquareTo(point);
if (distance_sqr < *min_distance_sqr) {
*min_distance_sqr = distance_sqr;
*nearest_object = object;
}
}
} else {
for (int i = 0; i < num_objects_; ++i) {
const double bound = objects_sorted_by_max_bound_[i];
// ?框的最大边界都比它小,说明在框外面的意思?
if (bound < pvalue && Square(bound - pvalue) > *min_distance_sqr) {
break;
}
ObjectPtr object = objects_sorted_by_max_[i];
const double distance_sqr = object->DistanceSquareTo(point);
if (distance_sqr < *min_distance_sqr) {
*min_distance_sqr = distance_sqr;
*nearest_object = object;
}
}
}
// 距离 == 0 , 找到最近邻,直接返回,不用深入了
if (*min_distance_sqr <= kMathEpsilon) {
return;
}
// Step 3.回溯时访问另一半子空间,看看是否还有距离更近的点
if (search_left_first) {
if (right_subnode_ != nullptr) {
right_subnode_->GetNearestObjectInternal(point, min_distance_sqr,
nearest_object);
}
} else {
if (left_subnode_ != nullptr) {
left_subnode_->GetNearestObjectInternal(point, min_distance_sqr,
nearest_object);
}
}
}
按距离查找原理类似。