Discussion 7: Linked Lists, Mutable Trees
Trees (Class)
Tree
instance has two instance attributes:
label
is the value stored at the root of the tree.branches
is a list ofTree
instances that hold the labels in the rest of the tree.
The Tree
class (with its __repr__
and __str__
methods omitted) is defined as:
class Tree:
"""
>>> t = Tree(3, [Tree(2, [Tree(5)]), Tree(4)])
>>> t.label
3
>>> t.branches[0].label
2
>>> t.branches[1].is_leaf()
True
"""
def __init__(self, label, branches=[]):
for b in branches:
assert isinstance(b, Tree)
self.label = label
self.branches = list(branches)
def is_leaf(self):
return not self.branches
To construct a Tree
instance from a label x
(any value) and a list of branches bs
(a list of Tree
instances) and give it the name t
, write t = Tree(x, bs)
.
For a tree t
:
- Its root label can be any value, and
t.label
evaluates to it. - Its branches are always
Tree
instances, andt.branches
evaluates to the list of its branches. t.is_leaf()
returnsTrue
ift.branches
is empty andFalse
otherwise.- To construct a leaf with label
x
, writeTree(x)
.
Displaying a tree t
:
repr(t)
returns a Python expression that evaluates to an equivalent tree.str(t)
returns one line for each label indented once more than its parent with children below their parents.
>>> t = Tree(3, [Tree(1, [Tree(4), Tree(1)]), Tree(5, [Tree(9)])])
>>> t # displays the contents of repr(t)
Tree(3, [Tree(1, [Tree(4), Tree(1)]), Tree(5, [Tree(9)])])
>>> print(t) # displays the contents of str(t)
3
1
4
1
5
9
Changing (also known as mutating) a tree t
:
t.label = y
changes the root label oft
toy
(any value).t.branches = ns
changes the branches oft
tons
(a list ofTree
instances).- Mutation of
t.branches
will changet
. For example,t.branches.append(Tree(y))
will add a leaf labeledy
as the right-most branch. - Mutation of any branch in
t
will changet
. For example,t.branches[0].label = y
will change the root label of the left-most branch toy
.
>>> t.label = 3.0
>>> t.branches[1].label = 5.0
>>> t.branches.append(Tree(2, [Tree(6)]))
>>> print(t)
3.0
1
4
1
5.0
9
2
6
Here is a summary of the differences between the tree data abstraction implemented as a functional abstraction vs. implemented as a class:
- | Tree constructor and selector functions | Tree class |
---|---|---|
Constructing a tree | To construct a tree given a label and a list of branches , we call tree(label, branches) |
To construct a tree object given a label and a list of branches , we call Tree(label, branches) (which calls the Tree.__init__ method). |
Label and branches | To get the label or branches of a tree t , we call label(t) or branches(t) respectively |
To get the label or branches of a tree t , we access the instance attributes t.label or t.branches respectively. |
Mutability | The functional tree data abstraction is immutable (without violating its abstraction barrier) because we cannot assign values to call expressions | The label and branches attributes of a Tree instance can be reassigned, mutating the tree. |
Checking if a tree is a leaf | To check whether a tree t is a leaf, we call the function is_leaf(t) |
To check whether a tree t is a leaf, we call the method t.is_leaf() . This method can only be called on Tree objects. |
Q1: Is BST
Write a function is_bst
, which takes a Tree t
and returns True
if, and
only if, t
is a valid binary search tree, which means that:
- Each node has at most two children (a leaf is automatically a valid binary search tree)
- The children are valid binary search trees
- For every node, the entries in that node's left child are less than or equal to the label of the node
- For every node, the entries in that node's right child are greater than the label of the node
An example of a BST is:
Your Answer Run in 61A CodeNote: If a node has only one child, that child could be considered either the left or right child. You should take this into consideration.
Hint: It may be helpful to write helper functions
bst_min
andbst_max
that return the minimum and maximum, respectively, of a Tree if it is a valid binary search tree.
def is_bst(t):
"""Returns True if the Tree t has the structure of a valid BST.
>>> t1 = Tree(6, [Tree(2, [Tree(1), Tree(4)]), Tree(7, [Tree(7), Tree(8)])])
>>> is_bst(t1)
True
>>> t2 = Tree(8, [Tree(2, [Tree(9), Tree(1)]), Tree(3, [Tree(6)]), Tree(5)])
>>> is_bst(t2)
False
>>> t3 = Tree(6, [Tree(2, [Tree(4), Tree(1)]), Tree(7, [Tree(7), Tree(8)])])
>>> is_bst(t3)
False
>>> t4 = Tree(1, [Tree(2, [Tree(3, [Tree(4)])])])
>>> is_bst(t4)
True
>>> t5 = Tree(1, [Tree(0, [Tree(-1, [Tree(-2)])])])
>>> is_bst(t5)
True
>>> t6 = Tree(1, [Tree(4, [Tree(2, [Tree(3)])])])
>>> is_bst(t6)
True
>>> t7 = Tree(2, [Tree(1, [Tree(5)]), Tree(4)])
>>> is_bst(t7)
False
"""
def bst_min(t):
"""Returns the min of t, if t has the structure of a valid BST."""
if t.is_leaf():
return t.label
return min(t.label, bst_min(t.branches[0]))
def bst_max(t):
"""Returns the max of t, if t has the structure of a valid BST."""
if t.is_leaf():
return t.label
return max(t.label, bst_max(t.branches[-1]))
if t.is_leaf():
return True
if len(t.branches) == 1:
c = t.branches[0]
return is_bst(c) and (bst_max(c) <= t.label or bst_min(c) > t.label)
elif len(t.branches) == 2:
c1, c2 = t.branches
valid_branches = is_bst(c1) and is_bst(c2)
return valid_branches and bst_max(c1) <= t.label and bst_min(c2) > t.label
else:
return False
Q2: Prune Small
Complete the function prune_small
that takes in a Tree
t
and a
number n
and prunes t
mutatively. If t
or any of its branches
has more than n
branches, the n
branches with the smallest labels
should be kept and any other branches should be pruned, or removed,
from the tree.
def prune_small(t, n):
"""Prune the tree mutatively, keeping only the n branches
of each node with the smallest labels.
>>> t1 = Tree(6)
>>> prune_small(t1, 2)
>>> t1
Tree(6)
>>> t2 = Tree(6, [Tree(3), Tree(4)])
>>> prune_small(t2, 1)
>>> t2
Tree(6, [Tree(3)])
>>> t3 = Tree(6, [Tree(1), Tree(3, [Tree(1), Tree(2), Tree(3)]), Tree(5, [Tree(3), Tree(4)])])
>>> prune_small(t3, 2)
>>> t3
Tree(6, [Tree(1), Tree(3, [Tree(1), Tree(2)])])
"""
while len(t.branches) > n:
largest = max(t.branches, key=lambda x: x.label)
t.branches.remove(largest)
for b in t.branches:
prune_small(b, n)
Linked Lists
A linked list is a data structure for storing a sequence of values. It is more
efficient than a regular built-in list for certain operations, such as inserting
a value in the middle of a long list. Linked lists are not built in, and so we
define a class called Link
to represent them.
A linked list is either a Link
instance or Link.empty
(which represents an empty linked list).
A instance of Link
has two instance attributes, first
and rest
.
The rest
attribute of a Link
instance should always be a linked list: either
another Link
instance or Link.empty
. It SHOULD NEVER be None
.
To check if a linked list is empty, compare it to Link.empty
. Since there is only
ever one empty list, we can use is
to compare, but ==
would work too.
def is_empty(s):
"""Return whether linked list s is empty."""
return s is Link.empty:
You can mutate a Link
object s
in two ways:
- Change the first element with
s.first = ...
- Change the rest of the elements with
s.rest = ...
You can make a new Link
object by calling Link
:
Link(4)
makes a linked list of length 1 containing 4.Link(4, s)
makes a linked list that starts with 4 followed by the elements of linked lists
.
Here is the implementation of the Link
class:
class Link:
"""A linked list is either a Link object or Link.empty
>>> s = Link(3, Link(4, Link(5)))
>>> s.rest
Link(4, Link(5))
>>> s.rest.rest.rest is Link.empty
True
>>> s.rest.first * 2
8
>>> print(s)
<3 4 5>
"""
empty = ()
def __init__(self, first, rest=empty):
assert rest is Link.empty or isinstance(rest, Link)
self.first = first
self.rest = rest
def __repr__(self):
if self.rest:
rest_repr = ', ' + repr(self.rest)
else:
rest_repr = ''
return 'Link(' + repr(self.first) + rest_repr + ')'
def __str__(self):
string = '<'
while self.rest is not Link.empty:
string += str(self.first) + ' '
self = self.rest
return string + str(self.first) + '>'
Q3: Sum Two Ways
Implement both sum_rec
and sum_iter
. Each one takes a linked list of numbers
s
and returns the sum of its elements. Use recursion to implement sum_rec
.
Don't use recursion to implement sum_iter
; use a while
loop instead.
def sum_rec(s):
"""
Returns the sum of the elements in s.
>>> a = Link(1, Link(6, Link(7)))
>>> sum_rec(a)
14
>>> sum_rec(Link.empty)
0
"""
# Use a recursive call to sum_rec
if s == Link.empty:
return 0
return s.first + sum_rec(s.rest)
def sum_iter(s):
"""
Returns the sum of the elements in s.
>>> a = Link(1, Link(6, Link(7)))
>>> sum_iter(a)
14
>>> sum_iter(Link.empty)
0
"""
# Don't call sum_rec or sum_iter
total = 0
while s != Link.empty:
total, s = total + s.first, s.rest
return total
s.first
to the sum of the elements in s.rest
. Your base case condition
should be s is Link.empty
so that you're checking whether s
is empty before
ever evaluating s.first
or s.rest
.
total
, then repeatedly (in a while
loop) add
s.first
to total
and set s = s.rest
to advance through the linked list, as
long as s is not Link.empty
.
Discussion time: When adding up numbers, the intermediate sums depend on the order.
(1 + 3) + 5 and 1 + (3 + 5) both equal 9, but the first one makes 4 along
the way while the second makes 8 along the way. For the same linked list, will
sum_rec
and sum_iter
both make the same intermediate sums along the way?
Q4: Overlap
Implement overlap
, which takes two linked lists of numbers called s
and t
that are sorted in increasing order and have no repeated elements within each
list. It returns the count of how many numbers appear in both lists.
This can be done in linear time in the combined length of s
and t
by
always advancing forward in the linked list whose first element is smallest
until both first elements are equal (add one to the count and advance both) or
one list is empty (time to return). Here's a
lecture video clip
about this (but the video uses Python lists instead of linked lists).
Take a vote to decide whether to use recursion or iteration. Either way works (and the solutions are about the same complexity/difficulty).
Your Answer Run in 61A Codedef overlap(s, t):
"""For increasing s and t, count the numbers that appear in both.
>>> a = Link(3, Link(4, Link(6, Link(7, Link(9, Link(10))))))
>>> b = Link(1, Link(3, Link(5, Link(7, Link(8)))))
>>> overlap(a, b) # 3 and 7
2
>>> overlap(a.rest, b) # just 7
1
>>> overlap(Link(0, a), Link(0, b))
3
"""
if s is Link.empty or t is Link.empty:
return 0
if s.first == t.first:
return 1 + overlap(s.rest, t.rest)
elif s.first < t.first:
return overlap(s.rest, t)
elif s.first > t.first:
return overlap(s, t.rest)
def overlap_iterative(s, t):
"""For increasing s and t, count the numbers that appear in both.
>>> a = Link(3, Link(4, Link(6, Link(7, Link(9, Link(10))))))
>>> b = Link(1, Link(3, Link(5, Link(7, Link(8)))))
>>> overlap(a, b) # 3 and 7
2
>>> overlap(a.rest, b) # just 7
1
>>> overlap(Link(0, a), Link(0, b))
3
"""
res = 0
while s is not Link.empty and t is not Link.empty:
if s.first == t.first:
res += 1
s = s.rest
t = t.rest
elif s.first < t.first:
s = s.rest
else:
t = t.rest
return res
if s is Link.empty or t is Link.empty:
return 0
if s.first == t.first:
return __________________
elif s.first < t.first:
return __________________
elif s.first > t.first:
return __________________
k = 0
while s is not Link.empty and t is not Link.empty:
if s.first == t.first:
__________________
elif s.first < t.first:
__________________
elif s.first > t.first:
__________________
return k
Q5: Duplicate Link
Write a function duplicate_link
that takes in a linked list s
and a value
. duplicate_link
will mutate s
such that if there is a linked list node that has a first
equal to value
, that node will be duplicated. Note that you should be mutating the original linked list s
; you will need to create new Link
s, but you should not be returning a new linked list.
Your Answer Run in 61A CodeNote: In order to insert a link into a linked list, you need to modify the
.rest
of certain links. We encourage you to draw out a doctest to visualize!
def duplicate_link(s, val):
"""Mutates s so that each element equal to val is followed by another val.
>>> x = Link(5, Link(4, Link(5)))
>>> duplicate_link(x, 5)
>>> x
Link(5, Link(5, Link(4, Link(5, Link(5)))))
>>> y = Link(2, Link(4, Link(6, Link(8))))
>>> duplicate_link(y, 10)
>>> y
Link(2, Link(4, Link(6, Link(8))))
>>> z = Link(1, Link(2, (Link(2, Link(3)))))
>>> duplicate_link(z, 2) # ensures that back to back links with val are both duplicated
>>> z
Link(1, Link(2, Link(2, Link(2, Link(2, Link(3))))))
"""
if s is Link.empty:
return
elif s.first == val:
remaining = s.rest
s.rest = Link(val, remaining)
duplicate_link(remaining, val)
else:
duplicate_link(s.rest, val)
Submit Attendance
You're done! Excellent work this week. Please be sure to ask your section TA for the attendance form link and fill it out for credit. (one submission per person per section).
Extra Challenge
This last question is similar in complexity to an A+ question on an exam. Feel free to skip it, but it's a fun one, so try it if you have time.
Q6: Decimal Expansion
Definition. The decimal expansion of a fraction n/d
with n < d
is an
infinite sequence of digits starting with the 0 before the decimal point and
followed by digits that represent the tenths, hundredths, and thousands place
(and so on) of the number n/d
. E.g., the decimal expansion of 2/3 is a zero
followed by an infinite sequence of 6's: 0.6666666....
Implement divide
, which takes positive integers n
and d
with n < d
. It
returns a linked list with a cycle containing the digits of the infinite decimal
expansion of n/d
. The provided display
function prints the first k
digits
after the decimal point.
For example, 1/22 would be represented as x
below:
>>> 1/22
0.045454545454545456
>>> x = Link(0, Link(0, Link(4, Link(5))))
>>> x.rest.rest.rest.rest = x.rest.rest
>>> display(x, 20)
0.04545454545454545454...
Your Answer
Run in 61A Code
def display(s, k=10):
"""Print the first k digits of infinite linked list s as a decimal.
>>> s = Link(0, Link(8, Link(3)))
>>> s.rest.rest.rest = s.rest.rest
>>> display(s)
0.8333333333...
"""
assert s.first == 0, f'{s.first} is not 0'
digits = f'{s.first}.'
s = s.rest
for _ in range(k):
assert s.first >= 0 and s.first < 10, f'{s.first} is not a digit'
digits += str(s.first)
s = s.rest
print(digits + '...')
def divide(n, d):
"""Return a linked list with a cycle containing the digits of n/d.
>>> display(divide(5, 6))
0.8333333333...
>>> display(divide(2, 7))
0.2857142857...
>>> display(divide(1, 2500))
0.0004000000...
>>> display(divide(3, 11))
0.2727272727...
>>> display(divide(3, 99))
0.0303030303...
>>> display(divide(2, 31), 50)
0.06451612903225806451612903225806451612903225806451...
"""
assert n > 0 and n < d
result = Link(0) # The zero before the decimal point
cache = {}
tail = result
while n not in cache:
q, r = 10 * n // d, 10 * n % d
tail.rest = Link(q)
tail = tail.rest
cache[n] = tail
n = r
tail.rest = cache[n]
return result
while
statement:
>>> q, r = 10 * n // d, 10 * n % d
>>> tail.rest = Link(q)
>>> tail = tail.rest
>>> n = r
While constructing the decimal expansion, store the tail
for each n
in a
dictionary keyed by n
. When some n
appears a second time, instead of
constructing a new Link
, set its original link as the rest of the previous
link. That will form a cycle of the appropriate length.