Recursively flatten a list
def recursive_flatten(target_list, flat_list=None):
    if flat_list is None:
        flat_list = []
    for item in target_list:
        if not isinstance(item, list):
            flat_list.append(item)
        else:
            recursive_flatten(item, flat_list)  # recurse on sublists
    return flat_list
#example
L = [1, [2, [3, 4], 5], 6, [7, 8]]
flat_list = recursive_flatten(L)
print(flat_list)
The list is the most versatile core data type in Python. It represents a sequence of ordered elements in which individual elements can be accessed by their indices.
A list can contain elements of varying data types, such as integers, strings, and even other lists.
When a list is made up of other lists, it is referred to as a nested list.
nested_list = [[1, 2], [3, 4], [5, 6]]
print(nested_list)
Sometimes we may need to reduce a nested list to a single list that contains all of the elements from the inner sub-sequences. This is referred to as flattening the list.
There are various approaches that we can use to efficiently and conveniently flatten a nested list. The approaches are as outlined below:
- Using list Comprehension
- Using itertools.chain()
- Using the reduce() Function
- Using the sum() function
- Recursive Flattening
Using list Comprehension
List comprehension is a convenient and efficient syntax for creating a new list from the elements of another iterable object. It combines the features of looping, conditional execution, and sequence building into one concise syntax.
Typically, the basic syntax of list comprehension is as shown below.
new_list = [<expression> for <item> in <target_iterable>]
We can use list comprehension with multiple for loops to flatten a nested list.
nested_list = [['Python', 'Java'], ['PHP', 'Ruby'], ['Swift', 'C++']]
flat_list = [i for sublist in nested_list for i in sublist]
print(flat_list)
In the above example, the first for loop in the list comprehension iterates through the sublists in nested_list, while the second for loop iterates through the values in each sublist and adds them to the new list.
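To make the order of the two for clauses concrete, the comprehension can be unrolled into ordinary nested loops. The following is only a sketch: the names flat_from_loops and flat_from_comprehension are illustrative and do not appear in the example above.

```python
nested_list = [['Python', 'Java'], ['PHP', 'Ruby'], ['Swift', 'C++']]

# Explicit nested loops: the outer loop corresponds to the first "for"
# clause of the comprehension, the inner loop to the second.
flat_from_loops = []
for sublist in nested_list:
    for i in sublist:
        flat_from_loops.append(i)

# The same traversal written as a list comprehension.
flat_from_comprehension = [i for sublist in nested_list for i in sublist]

print(flat_from_loops)
print(flat_from_loops == flat_from_comprehension)
```

Both forms visit the elements in the same order, so the two resulting lists compare equal.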
Using itertools.chain()
The itertools module in the standard library offers convenient and efficient tools for iterating through collections of elements such as lists.
The chain() function in the module is used to iterate through the elements of multiple iterables as if they were all connected together in a single sequence.
itertools.chain(*iterables)
Here, the *iterables argument represents an arbitrary number of iterables to be 'chained'. The function creates an iterator object that iterates through the elements of each given iterable in turn.
We can use this function with the unpacking operator (*) to create a flat list from a nested list.
#import the chain() function
from itertools import chain
#the list to flatten
nested_list = [['Paris', 'Berlin'], ['Tokyo', 'Manilla'], ['Nairobi', 'Kigali']]
#use the chain function.
flat_iterator = chain(*nested_list)
#create a flat list from the iterator object
flat_list = list(flat_iterator)
print(flat_list)
In the above example, the unpacking operator (*) unpacks the nested list, so that each inner list is passed as a separate argument to the chain() function. Since the chain() function returns an iterator object, we created a new list with the elements from the iterator using the list() function.
The chain() function also has an alternative constructor, the class method from_iterable(), which takes the nested list as a single argument, so we do not have to unpack it explicitly. Using this method, the previous example would look as shown below:
from itertools import chain
#the list to flatten
nested_list = [['Paris', 'Berlin'], ['Tokyo', 'Manilla'], ['Nairobi', 'Kigali']]
flat_iterator = chain.from_iterable(nested_list) #use the from_iterable method
#create a flat list from the iterator object
flat_list = list(flat_iterator)
print(flat_list)
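One practical difference between the two forms is that chain.from_iterable() consumes the outer iterable lazily, so it also works when the sublists come from a generator. The following is a minimal sketch in which the generator function rows() is a hypothetical stand-in for any source of sublists:

```python
from itertools import chain

def rows():
    # Hypothetical generator that yields sublists one at a time.
    yield ['Paris', 'Berlin']
    yield ['Tokyo', 'Manilla']
    yield ['Nairobi', 'Kigali']

flat_iterator = chain.from_iterable(rows())

# Elements are produced on demand: take one, then collect the remainder.
first = next(flat_iterator)
rest = list(flat_iterator)

print(first)
print(rest)
```

With chain(*rows()), by contrast, the unpacking operator would have to exhaust the generator before chain() is even called.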
Using the reduce() function
The functools module in the standard library offers useful tools for use in functional programming.
The reduce() function in the module is a convenient tool when we need to apply a function to an iterable and reduce it to a single cumulative value.
reduce(function, iterable[, initial])
function | Required. A two-parameter function to be applied to the elements.
iterable | Required. The iterable to be reduced.
initial | Optional. The initial value of the accumulation. If it is omitted, the first item of the iterable is used instead.
The reduce() function applies the function cumulatively to the items of the iterable, from left to right, and returns a single cumulative value. Each item in the iterable is passed to the function alongside the cumulative result so far.
We can use the reduce() function to flatten a nested list by cumulatively concatenating the inner lists.
#import the reduce function
from functools import reduce
#concatenating function
def concat(lst1, lst2):
    return lst1 + lst2
#the list to flatten
nested_list = [['apple', 'banana'], ['pear', 'orange'], ['mango', 'berry']]
#use the reduce function
flat_list = reduce(concat, nested_list, [])
print(flat_list)
In the above example, the concat() function is passed as the function argument to reduce(). It is applied cumulatively to the inner lists, so that eventually all of them are concatenated into one list. We set the initial value to an empty list ([]). The reduce() function has no default initial value: when the argument is omitted, the first inner list is used as the starting value. That works for a non-empty nested list, but a TypeError is raised if the nested list is empty.
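The effect of the initial argument is easiest to see on small inputs. The following is a minimal sketch; the variable names without_initial, empty_with_initial and raised are illustrative only.

```python
from functools import reduce

def concat(lst1, lst2):
    return lst1 + lst2

nested_list = [['apple', 'banana'], ['pear', 'orange']]

# Without an initial value, the first inner list seeds the accumulation.
without_initial = reduce(concat, nested_list)

# With an empty iterable, the initial value itself is returned.
empty_with_initial = reduce(concat, [], [])

# With an empty iterable and no initial value, reduce() raises TypeError.
try:
    reduce(concat, [])
    raised = False
except TypeError:
    raised = True

print(without_initial)
print(empty_with_initial)
print(raised)
```

Passing [] as the initial value is therefore mainly a safeguard for the case where the nested list might be empty.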
The operator module in the standard library provides the add() function, which we can use instead of defining one from scratch as we did above.
from functools import reduce
from operator import add
nested_list = [['apple', 'banana'], ['pear', 'orange'], ['mango', 'berry']]
#use the add() function
flat_list = reduce(add, nested_list, [])
print(flat_list)
Using the sum() Function
When called with a list or another iterable, the built-in function sum() applies the + operation to all elements of the iterable and returns the result.
The sum() function, just like reduce(), allows us to specify an initial value. It is passed as the start argument, which defaults to 0.
sum(iterable, start=0)
When used with a nested list and an empty list as the start value, the inner lists are concatenated one after another.
nested_list = [['Japan', 'India'], ['Rwanda', 'Kenya'], ['Spain', 'Italy']]
#use the sum function to flatten the list
flat_list = sum(nested_list, [])
print(flat_list)
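In the above example, we specified the initial value to be an empty list. This ensures that the concatenation happens as intended: with the default start value of 0, the first step would try to add an integer to a list, which raises a TypeError.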
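The role of the start value can be checked directly. The sketch below calls sum() with and without it; the name error_message is illustrative and not part of the example above.

```python
nested_list = [['Japan', 'India'], ['Rwanda', 'Kenya'], ['Spain', 'Italy']]

# With the default start value (0), the first step is 0 + ['Japan', 'India'],
# which is not a valid operation and raises TypeError.
try:
    sum(nested_list)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Passing an empty list as the start value makes every step list + list.
flat_list = sum(nested_list, [])

print(error_message)
print(flat_list)
```

Note that each + builds a new list, so this approach is best kept for short nested lists.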
Recursive Flattening
All the previous approaches remove only one level of nesting, so they work as intended only when the target list has a uniform nesting depth. It is impractical to use them to flatten an irregularly nested list without adding a lot of extra conditional checks.
Consider the following example, where we use the list comprehension approach with an irregularly nested list.
nested_list = [1, [2, [3, 4], 5], 6, [7, 8]]
flat_list = [i for sublist in nested_list for i in sublist]
print(flat_list)
A TypeError exception is raised because the comprehension tries to iterate through an integer (at index 0), which is not iterable. All the previous approaches would lead to the same or similar errors.
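The limitation has two sides, which the following sketch separates. The names irregular, deeper and one_level are illustrative and not from the example above.

```python
irregular = [1, [2, [3, 4], 5], 6, [7, 8]]

# Side 1: a non-iterable element at the top level raises TypeError.
try:
    flat_list = [i for sublist in irregular for i in sublist]
    error = None
except TypeError as exc:
    error = exc

print(type(error).__name__)

# Side 2: even when every top-level element is a list,
# only one level of nesting is removed.
deeper = [[1, [2, 3]], [4, [5, 6]]]
one_level = [i for sublist in deeper for i in sublist]
print(one_level)
```

In the second case no error is raised, but the inner lists [2, 3] and [5, 6] are still present in the result.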
Recursive flattening may not be as direct as the other approaches, but it helps us overcome the above limitation. It makes it possible to easily flatten an irregularly nested list.
We can implement recursive flattening as shown below.
recursive flattening implementation
def recursive_flatten(target_list, flat_list=None):
    if flat_list is None:
        flat_list = []
    for item in target_list:
        if not isinstance(item, list):
            flat_list.append(item)
        else:
            recursive_flatten(item, flat_list)  # recurse on sublists
    return flat_list
#example
nested_list = [1, [2, [3, 4], 5], 6, [7, [8, 9]]]
flat_list = recursive_flatten(nested_list)
print(flat_list)
In the above example, we implemented the recursive_flatten() function. In the body of the function, we check whether the current element is a list using isinstance(). If the element is not a list, we simply append it to flat_list; otherwise, we recursively flatten it, passing the same flat_list along so that all items end up in one list. Finally, flat_list is returned.
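The function above only descends into list objects. If tuples, sets, or other iterables should be flattened too, one possible variant checks against collections.abc.Iterable instead. The following is a sketch under that assumption, not a replacement for the implementation above; the name iter_flatten is hypothetical, and treating strings and bytes as single items is a design choice made here.

```python
from collections.abc import Iterable

def iter_flatten(target):
    """Yield the leaf items of an arbitrarily nested iterable."""
    for item in target:
        # Strings and bytes are iterable, but are treated as leaves here;
        # otherwise a one-character string would recurse without end.
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            yield from iter_flatten(item)
        else:
            yield item

nested = [1, (2, [3, 4], 5), 'six', [7, {8}]]
flat_list = list(iter_flatten(nested))
print(flat_list)
```

Because it is a generator, this variant produces items on demand, and list() is only needed when an actual list is required.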