|
@@ -16,6 +16,8 @@ ListNode* list_get(ListNode** head) {
|
|
|
ListNode* node = *head;
|
|
|
if (node) {
|
|
|
*head = node->next;
|
|
|
+ if (*head)
|
|
|
+ (*head)->prev = nullptr;
|
|
|
|
|
|
node->next = nullptr;
|
|
|
node->prev = nullptr;
|
|
@@ -25,10 +27,18 @@ ListNode* list_get(ListNode** head) {
|
|
|
|
|
|
template <typename ListNode>
|
|
|
void list_remove(ListNode** head, ListNode* node) {
|
|
|
- if (node->prev)
|
|
|
- node->prev->next = node->next;
|
|
|
- else
|
|
|
+ if (node == *head) {
|
|
|
+ assert(!node->prev);
|
|
|
+
|
|
|
*head = node->next;
|
|
|
+ if (*head)
|
|
|
+ (*head)->prev = nullptr;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ assert(node->prev);
|
|
|
+
|
|
|
+ node->prev->next = node->next;
|
|
|
+ }
|
|
|
|
|
|
if (node->next)
|
|
|
node->next->prev = node->prev;
|