map.hpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749
  1. #pragma once
  2. #include <utility>
  3. #include <type_traits>
  4. #include <types/allocator.hpp>
  5. #include <types/pair.hpp>
  6. #include <types/types.h>
  7. namespace types {
  8. template <typename Key, typename Value, template <typename _T> class _Allocator = kernel_allocator>
  9. class map {
  10. public:
  11. using key_type = std::add_const_t<Key>;
  12. using value_type = Value;
  13. using pair_type = pair<key_type, value_type>;
  14. struct node {
  15. node* parent = nullptr;
  16. node* left = nullptr;
  17. node* right = nullptr;
  18. enum class node_color {
  19. RED,
  20. BLACK,
  21. } color
  22. = node_color::RED;
  23. pair_type v;
  24. constexpr node(pair_type&& pair)
  25. : v(std::move(pair))
  26. {
  27. }
  28. constexpr node(const pair_type& pair)
  29. : v(pair)
  30. {
  31. }
  32. constexpr node* grandparent(void) const
  33. {
  34. return this->parent->parent;
  35. }
  36. constexpr node* uncle(void) const
  37. {
  38. node* pp = this->grandparent();
  39. return (this->parent == pp->left) ? pp->right : pp->left;
  40. }
  41. constexpr node* leftmost(void)
  42. {
  43. node* nd = this;
  44. while (nd->left)
  45. nd = nd->left;
  46. return nd;
  47. }
  48. constexpr const node* leftmost(void) const
  49. {
  50. const node* nd = this;
  51. while (nd->left)
  52. nd = nd->left;
  53. return nd;
  54. }
  55. constexpr node* rightmost(void)
  56. {
  57. node* nd = this;
  58. while (nd->right)
  59. nd = nd->right;
  60. return nd;
  61. }
  62. constexpr const node* rightmost(void) const
  63. {
  64. const node* nd = this;
  65. while (nd->right)
  66. nd = nd->right;
  67. return nd;
  68. }
  69. constexpr node* next(void)
  70. {
  71. if (this->right) {
  72. return this->right->leftmost();
  73. } else {
  74. if (this->is_root()) {
  75. return nullptr;
  76. } else if (this->is_left_child()) {
  77. return this->parent;
  78. } else {
  79. node* ret = this;
  80. do {
  81. ret = ret->parent;
  82. } while (!ret->is_root() && !ret->is_left_child());
  83. return ret->parent;
  84. }
  85. }
  86. }
  87. constexpr const node* next(void) const
  88. {
  89. if (this->right) {
  90. return this->right->leftmost();
  91. } else {
  92. if (this->is_root()) {
  93. return nullptr;
  94. } else if (this->is_left_child()) {
  95. return this->parent;
  96. } else {
  97. const node* ret = this;
  98. do {
  99. ret = ret->parent;
  100. } while (!ret->is_root() && !ret->is_left_child());
  101. return ret->parent;
  102. }
  103. }
  104. }
  105. constexpr node* prev(void)
  106. {
  107. if (this->left) {
  108. return this->left->rightmost();
  109. } else {
  110. if (this->is_root()) {
  111. return nullptr;
  112. } else if (this->is_right_child()) {
  113. return this->parent;
  114. } else {
  115. node* ret = this;
  116. do {
  117. ret = ret->parent;
  118. } while (!ret->is_root() && !ret->is_right_child());
  119. return ret->parent;
  120. }
  121. }
  122. }
  123. static constexpr bool is_red(node* nd)
  124. {
  125. return nd && nd->color == node_color::RED;
  126. }
  127. static constexpr bool is_black(node* nd)
  128. {
  129. return !node::is_red(nd);
  130. }
  131. constexpr const node* prev(void) const
  132. {
  133. if (this->left) {
  134. return this->left->rightmost();
  135. } else {
  136. if (this->is_root()) {
  137. return nullptr;
  138. } else if (this->is_right_child()) {
  139. return this->parent;
  140. } else {
  141. const node* ret = this;
  142. do {
  143. ret = ret->parent;
  144. } while (!ret->is_root() && !ret->is_right_child());
  145. return ret->parent;
  146. }
  147. }
  148. }
  149. constexpr bool is_root(void) const
  150. {
  151. return this->parent == nullptr;
  152. }
  153. constexpr bool is_full(void) const
  154. {
  155. return this->left && this->right;
  156. }
  157. constexpr bool has_child(void) const
  158. {
  159. return this->left || this->right;
  160. }
  161. constexpr bool is_leaf(void) const
  162. {
  163. return !this->has_child();
  164. }
  165. constexpr bool is_left_child(void) const
  166. {
  167. return this == this->parent->left;
  168. }
  169. constexpr bool is_right_child(void) const
  170. {
  171. return this == this->parent->right;
  172. }
  173. constexpr void tored(void)
  174. {
  175. this->color = node_color::RED;
  176. }
  177. constexpr void toblack(void)
  178. {
  179. this->color = node_color::BLACK;
  180. }
  181. static constexpr void swap(node* first, node* second)
  182. {
  183. if (node::is_red(first)) {
  184. first->color = second->color;
  185. second->color = node_color::RED;
  186. } else {
  187. first->color = second->color;
  188. second->color = node_color::BLACK;
  189. }
  190. if (first->parent == second) {
  191. node* tmp = first;
  192. first = second;
  193. second = tmp;
  194. }
  195. bool f_is_left_child = first->parent ? first->is_left_child() : false;
  196. bool s_is_left_child = second->parent ? second->is_left_child() : false;
  197. node* fp = first->parent;
  198. node* fl = first->left;
  199. node* fr = first->right;
  200. node* sp = second->parent;
  201. node* sl = second->left;
  202. node* sr = second->right;
  203. if (second->parent != first) {
  204. first->parent = sp;
  205. if (sp) {
  206. if (s_is_left_child)
  207. sp->left = first;
  208. else
  209. sp->right = first;
  210. }
  211. first->left = sl;
  212. if (sl)
  213. sl->parent = first;
  214. first->right = sr;
  215. if (sr)
  216. sr->parent = first;
  217. second->parent = fp;
  218. if (fp) {
  219. if (f_is_left_child)
  220. fp->left = second;
  221. else
  222. fp->right = second;
  223. }
  224. second->left = fl;
  225. if (fl)
  226. fl->parent = second;
  227. second->right = fr;
  228. if (fr)
  229. fr->parent = second;
  230. } else {
  231. first->left = sl;
  232. if (sl)
  233. sl->parent = first;
  234. first->right = sr;
  235. if (sr)
  236. sr->parent = first;
  237. second->parent = fp;
  238. if (fp) {
  239. if (f_is_left_child)
  240. fp->left = second;
  241. else
  242. fp->right = second;
  243. }
  244. first->parent = second;
  245. if (s_is_left_child) {
  246. second->left = first;
  247. second->right = fr;
  248. if (fr)
  249. fr->parent = second;
  250. } else {
  251. second->right = first;
  252. second->left = fl;
  253. if (fl)
  254. fl->parent = second;
  255. }
  256. }
  257. }
  258. };
  259. using allocator_type = _Allocator<node>;
  260. template <bool Const>
  261. class iterator {
  262. public:
  263. using node_pointer_type = std::conditional_t<Const, const node*, node*>;
  264. using value_type = std::conditional_t<Const, const pair_type, pair_type>;
  265. using pointer_type = std::add_pointer_t<value_type>;
  266. using reference_type = std::add_lvalue_reference_t<value_type>;
  267. friend class map;
  268. private:
  269. node_pointer_type p;
  270. public:
  271. explicit constexpr iterator(node_pointer_type ptr)
  272. : p { ptr }
  273. {
  274. }
  275. constexpr iterator(const iterator& iter)
  276. : p { iter.p }
  277. {
  278. }
  279. constexpr iterator(iterator&& iter)
  280. : p { iter.p }
  281. {
  282. iter.p = nullptr;
  283. }
  284. constexpr ~iterator()
  285. {
  286. #ifndef NDEBUG
  287. p = nullptr;
  288. #endif
  289. }
  290. constexpr iterator& operator=(const iterator& iter)
  291. {
  292. p = iter.p;
  293. return *this;
  294. }
  295. constexpr iterator& operator=(iterator&& iter)
  296. {
  297. p = iter.p;
  298. iter.p = nullptr;
  299. return *this;
  300. }
  301. constexpr bool operator==(const iterator& iter) const
  302. {
  303. return p == iter.p;
  304. }
  305. constexpr bool operator!=(const iterator& iter) const
  306. {
  307. return !this->operator==(iter);
  308. }
  309. constexpr reference_type operator*(void) const
  310. {
  311. return p->v;
  312. }
  313. constexpr pointer_type operator&(void) const
  314. {
  315. return &p->v;
  316. }
  317. constexpr pointer_type operator->(void) const
  318. {
  319. return this->operator&();
  320. }
  321. constexpr iterator& operator++(void)
  322. {
  323. p = p->next();
  324. return *this;
  325. }
  326. constexpr iterator operator++(int)
  327. {
  328. iterator ret(p);
  329. (void)this->operator++();
  330. return ret;
  331. }
  332. constexpr iterator& operator--(void)
  333. {
  334. p = p->prev();
  335. return *this;
  336. }
  337. constexpr iterator operator--(int)
  338. {
  339. iterator ret(p);
  340. (void)this->operator--();
  341. return ret;
  342. }
  343. explicit constexpr operator bool(void)
  344. {
  345. return p;
  346. }
  347. };
  348. using iterator_type = iterator<false>;
  349. using const_iterator_type = iterator<true>;
  350. private:
  351. node* root = nullptr;
  352. private:
  353. static constexpr node* newnode(node* parent, const pair_type& val)
  354. {
  355. auto* ptr = allocator_traits<allocator_type>::allocate_and_construct(val);
  356. ptr->parent = parent;
  357. return ptr;
  358. }
  359. static constexpr node* newnode(node* parent, pair_type&& val)
  360. {
  361. auto* ptr = allocator_traits<allocator_type>::allocate_and_construct(std::move(val));
  362. ptr->parent = parent;
  363. return ptr;
  364. }
  365. static constexpr void delnode(node* nd)
  366. {
  367. allocator_traits<allocator_type>::deconstruct_and_deallocate(nd);
  368. }
  369. constexpr void rotateleft(node* rt)
  370. {
  371. node* nrt = rt->right;
  372. if (!rt->is_root()) {
  373. if (rt->is_left_child()) {
  374. rt->parent->left = nrt;
  375. } else {
  376. rt->parent->right = nrt;
  377. }
  378. } else {
  379. this->root = nrt;
  380. }
  381. nrt->parent = rt->parent;
  382. rt->parent = nrt;
  383. rt->right = nrt->left;
  384. nrt->left = rt;
  385. }
  386. constexpr void rotateright(node* rt)
  387. {
  388. node* nrt = rt->left;
  389. if (!rt->is_root()) {
  390. if (rt->is_left_child()) {
  391. rt->parent->left = nrt;
  392. } else {
  393. rt->parent->right = nrt;
  394. }
  395. } else {
  396. this->root = nrt;
  397. }
  398. nrt->parent = rt->parent;
  399. rt->parent = nrt;
  400. rt->left = nrt->right;
  401. nrt->right = rt;
  402. }
  403. constexpr void balance(node* nd)
  404. {
  405. if (nd->is_root()) {
  406. nd->toblack();
  407. return;
  408. }
  409. if (node::is_black(nd->parent))
  410. return;
  411. node* p = nd->parent;
  412. node* pp = nd->grandparent();
  413. node* uncle = nd->uncle();
  414. if (node::is_red(uncle)) {
  415. p->toblack();
  416. uncle->toblack();
  417. pp->tored();
  418. this->balance(pp);
  419. return;
  420. }
  421. if (p->is_left_child()) {
  422. if (nd->is_left_child()) {
  423. p->toblack();
  424. pp->tored();
  425. this->rotateright(pp);
  426. } else {
  427. this->rotateleft(p);
  428. this->balance(p);
  429. }
  430. } else {
  431. if (nd->is_right_child()) {
  432. p->toblack();
  433. pp->tored();
  434. this->rotateleft(pp);
  435. } else {
  436. this->rotateright(p);
  437. this->balance(p);
  438. }
  439. }
  440. }
  441. constexpr node* _find(const key_type& key) const
  442. {
  443. node* cur = root;
  444. for (; cur;) {
  445. if (cur->v.key == key)
  446. return cur;
  447. if (key < cur->v.key)
  448. cur = cur->left;
  449. else
  450. cur = cur->right;
  451. }
  452. return nullptr;
  453. }
  454. // this function DOES NOT dellocate the node
  455. // caller is responsible for freeing the memory
  456. // @param: nd is guaranteed to be a leaf node
  457. constexpr void _erase(node* nd)
  458. {
  459. if (nd->is_root())
  460. return;
  461. if (node::is_black(nd)) {
  462. node* p = nd->parent;
  463. node* s = nullptr;
  464. if (nd->is_left_child())
  465. s = p->right;
  466. else
  467. s = p->left;
  468. if (node::is_red(s)) {
  469. p->tored();
  470. s->toblack();
  471. if (nd->is_right_child()) {
  472. this->rotateright(p);
  473. s = p->left;
  474. } else {
  475. this->rotateleft(p);
  476. s = p->right;
  477. }
  478. }
  479. node* r = nullptr;
  480. if (node::is_red(s->left)) {
  481. r = s->left;
  482. if (s->is_left_child()) {
  483. r->toblack();
  484. s->color = p->color;
  485. this->rotateright(p);
  486. p->toblack();
  487. } else {
  488. r->color = p->color;
  489. this->rotateright(s);
  490. this->rotateleft(p);
  491. p->toblack();
  492. }
  493. } else if (node::is_red(s->right)) {
  494. r = s->right;
  495. if (s->is_left_child()) {
  496. r->color = p->color;
  497. this->rotateleft(s);
  498. this->rotateright(p);
  499. p->toblack();
  500. } else {
  501. r->toblack();
  502. s->color = p->color;
  503. this->rotateleft(p);
  504. p->toblack();
  505. }
  506. } else {
  507. s->tored();
  508. if (node::is_black(p))
  509. this->_erase(p);
  510. else
  511. p->toblack();
  512. }
  513. }
  514. }
  515. public:
  516. constexpr iterator_type end(void)
  517. {
  518. return iterator_type(nullptr);
  519. }
  520. constexpr const_iterator_type end(void) const
  521. {
  522. return const_iterator_type(nullptr);
  523. }
  524. constexpr const_iterator_type cend(void) const
  525. {
  526. return const_iterator_type(nullptr);
  527. }
  528. constexpr iterator_type begin(void)
  529. {
  530. return root ? iterator_type(root->leftmost()) : end();
  531. }
  532. constexpr const_iterator_type begin(void) const
  533. {
  534. return root ? const_iterator_type(root->leftmost()) : end();
  535. }
  536. constexpr const_iterator_type cbegin(void) const
  537. {
  538. return root ? const_iterator_type(root->leftmost()) : end();
  539. }
  540. constexpr iterator_type find(const key_type& key)
  541. {
  542. return iterator_type(_find(key));
  543. }
  544. constexpr const_iterator_type find(const key_type& key) const
  545. {
  546. return const_iterator_type(_find(key));
  547. }
  548. constexpr iterator_type insert(pair_type&& val)
  549. {
  550. node* cur = root;
  551. while (likely(cur)) {
  552. if (val.key < cur->v.key) {
  553. if (!cur->left) {
  554. node* nd = newnode(cur, std::move(val));
  555. cur->left = nd;
  556. this->balance(nd);
  557. return iterator_type(nd);
  558. } else {
  559. cur = cur->left;
  560. }
  561. } else {
  562. if (!cur->right) {
  563. node* nd = newnode(cur, std::move(val));
  564. cur->right = nd;
  565. this->balance(nd);
  566. return iterator_type(nd);
  567. } else {
  568. cur = cur->right;
  569. }
  570. }
  571. }
  572. root = newnode(nullptr, std::move(val));
  573. root->toblack();
  574. return iterator_type(root);
  575. }
  576. constexpr iterator_type erase(const iterator_type& iter)
  577. {
  578. node* nd = iter.p;
  579. if (!nd)
  580. return end();
  581. if (nd->is_root() && nd->is_leaf()) {
  582. delnode(nd);
  583. root = nullptr;
  584. return end();
  585. }
  586. node* next = nd->next();
  587. while (!nd->is_leaf()) {
  588. node* alt = nd->right ? nd->right->leftmost() : nd->left;
  589. if (nd->is_root()) {
  590. this->root = alt;
  591. }
  592. node::swap(nd, alt);
  593. }
  594. this->_erase(nd);
  595. if (nd->is_left_child())
  596. nd->parent->left = nullptr;
  597. else
  598. nd->parent->right = nullptr;
  599. delnode(nd);
  600. return iterator_type(next);
  601. }
  602. constexpr void remove(const key_type& key)
  603. {
  604. auto iter = this->find(key);
  605. if (iter != this->end())
  606. this->erase(iter);
  607. }
  608. // destroy a subtree without adjusting nodes to maintain binary tree properties
  609. constexpr void destroy(node* nd)
  610. {
  611. if (nd) {
  612. this->destroy(nd->left);
  613. this->destroy(nd->right);
  614. delnode(nd);
  615. }
  616. }
  617. explicit constexpr map(void)
  618. {
  619. }
  620. constexpr map(const map& val)
  621. {
  622. for (const auto& item : val)
  623. this->insert(item);
  624. }
  625. constexpr map(map&& val)
  626. : root(val.root)
  627. {
  628. val.root = nullptr;
  629. }
  630. constexpr map& operator=(const map& val)
  631. {
  632. this->destroy(root);
  633. for (const auto& item : val)
  634. this->insert(item);
  635. }
  636. constexpr map& operator=(map&& val)
  637. {
  638. this->destroy(root);
  639. root = val.root;
  640. val.root = nullptr;
  641. }
  642. constexpr ~map()
  643. {
  644. this->destroy(root);
  645. }
  646. };
  647. } // namespace types