I really had no idea how much new/reworked code or time this menu refactoring would require. And trying to structure it into individual blog posts is also a bit challenging given the lack of a real plan on my part.

Module chart.chart

Okay, we now have our three functions in the database/population module for getting the appropriate data for each of the three chart types. Now, let’s move on to getting the chart package to use these functions to actually display the user-defined chart.

I am going to start with a new function, plot_population(), which will be called from the menu loop in the main module of the application. It will take the user-supplied data as an argument and use that to plot the desired chart. I almost called it plot_chart() but decided that was too generic for what this function is supposed to do. And we may yet have a plot_chart() function that is called from plot_population(), which would pass it all the information needed to plot the chart; plot_chart() itself wouldn’t have any idea what it was plotting.

chart.plot_population(chart_dtls)

I decided to let this function figure out which plot type is being requested based on the data passed in chart_dtls. For our case of only three charts, that seemed reasonable enough. I also didn’t bother adding a separate function to make that determination; again, for three types, why bother? It then calls the appropriate population module function to get the required population data from the CSV data file.

Identifying the chart type was pretty straightforward. If there was only one country/region name specified and the age group value was “all”, it was a Type 1 chart. Similarly, if the number of years to plot was 1 and the age group value was “all”, it was a Type 2 chart. If the age group value was not “all”, then it was a Type 3 chart. I might have reduced the code for the chart type tests by working in a different order. But…
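Just to make those three tests concrete, here is a minimal sketch of them pulled out into a tiny helper. The name chart_type() is hypothetical; the post keeps the tests inline in plot_population(). It assumes chart_dtls has the same [names, [start year, number of years], [age group]] shape used throughout.

```python
def chart_type(chart_dtls):
  """Return 1, 2 or 3 for the chart types described above, or None if no test matches."""
  names, years, groups = chart_dtls
  all_groups = groups[0] == 'all'
  if len(names) == 1 and all_groups:
    return 1  # one country/region, one or more years, all age groups
  if years[1] == 1 and all_groups:
    return 2  # several countries/regions, a single year, all age groups
  if not all_groups:
    return 3  # one specific age group
  return None  # several names, several years, all age groups: no chart type covers this

print(chart_type([['Zimbabwe'], ['2005', 2], ['all']]))             # 1
print(chart_type([['Zimbabwe', 'Mozambique'], ['2010', 1], ['all']]))   # 2
print(chart_type([['Zimbabwe', 'Mozambique'], ['2010', 5], ['65-69']])) # 3
```

Note the one combination that falls through all three tests: more than one name, more than one year and “all” age groups.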

Going only as far as making the chart determination and calling the appropriate database function, I ended up with the following:

def plot_population(chart_dtls):
  plot_data = [[], [], []]
  p_nms, p_yrs, p_grp = chart_dtls
  yr_list = []  # to send list of years to population module function if necessary
  if len(p_nms) == 1 and p_grp[0] == 'all':
    # type 1 chart
    plot_data[0] = p_nms
    plot_data[2] = p_grp
    # get data for each year specified by p_yrs (start yr, nbr yrs)
    strt_yr = int(p_yrs[0])
    end_yr = strt_yr + p_yrs[1]
    for yr in range(strt_yr, end_yr):
      yr_list.append(str(yr))
    #print(f"pdb.get_1cr_years({p_nms[0]}, {yr_list}) = ")
    dbg_data = pdb.get_1cr_years_all(p_nms[0], yr_list)
    #plot_data[1].append(pdb.get_1cr_years_all(p_nms[0], p_yrs))
    plot_data[1].append(dbg_data)

  elif p_yrs[1] == 1 and p_grp[0] == 'all':
    # type 2 chart
    plot_data[1] = p_yrs
    plot_data[2] = p_grp
    # get data for each country in p_nms
    dbg_data = pdb.get_crs_1yr_all(p_nms, p_yrs[0])
    plot_data[0].append(dbg_data)

  elif p_grp[0] != 'all':
    # type 3 chart
    plot_data[1] = p_yrs
    plot_data[2] = p_grp
    strt_yr = int(p_yrs[0])
    end_yr = strt_yr + p_yrs[1]
    for yr in range(strt_yr, end_yr):
      yr_list.append(str(yr))
    # get data for combination of each name in p_nms and each year specified by p_yrs
    dbg_data = pdb.get_crs_years_one(p_nms, yr_list, p_grp[0])
    plot_data[0].append(dbg_data)
    
  else:
    # wth?!
    pass

  # dev testing/debug
  print(plot_data)

Going to get a bit long, but the logic is straightforward enough that I don’t think that will be an issue. If it becomes one, well, we have options.

Module Testing

I have added my usual if __name__ block to allow testing of the code as I write it. But, this time to give me more flexibility in the testing process I am using argparse to get a command line parameter specifying which test to run from a list of possible tests. Was getting tired of changing a variable, e.g. do_tst_1 = ???, in the code each time I wanted to do a different test. Though that is still possible.

Added the following to the bottom of the chart/chart module.

if __name__ == '__main__':
  import argparse
  # check for test number on command line
  parser = argparse.ArgumentParser()
  # long name preceded by --, short by single -,
  # get it as an integer so can use to access test data array
  parser.add_argument('--do_test', '-t', type=int, help='test number (1-3)')
  
  args = parser.parse_args()
  if args.do_test is not None and 1 <= args.do_test <= 3:
    do_tst = args.do_test
  else:
    do_tst = 1

  tst_deets = [
      [],
      [['Zimbabwe'], ['2005', 2], ['all']],
      [['Venezuela (Bolivarian Republic of)', 'Chile'], ['2010', 1], ['all']],
      [['Zimbabwe', 'Mozambique'], ['2010', 5], ['65-69']]
      ]
  print(f"\nchart data: {tst_deets[do_tst]}\n")
  plot_population(tst_deets[do_tst])
  print('\n')

So, executing python r:/learn/py_play/population/chart/chart.py -t 2 in a suitably configured Anaconda PowerShell produced the following:

(base-3.8) PS R:\learn\py_play> python r:/learn/py_play/population/chart/chart.py -t 2

chart data: [['Venezuela (Bolivarian Republic of)', 'Chile'], ['2010', 1], ['all']]

[[{
'Chile': [1224.229, 1215.911, 1327.57, 1462.566, 1422.594, 1317.564, 1280.183, 1239.245, 1220.229, 1166.694, 1034.5, 884.093, 669.755, 519.24, 401.13, 310.239, 218.714, 100.888, 36.256, 9.329, 1.602],
'Venezuela (Bolivarian Republic of)': [2900.078, 2847.209, 2758.383, 2698.861, 2619.602, 2368.262, 2163.254, 1922.294, 1795.301, 1587.041, 1343.277, 1046.425, 806.459, 597.682, 421.854, 274.75, 162.923, 82.601, 32.885, 9.213, 1.588]
}], ['2010', 1], ['all']]
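One thing to keep in mind with the test-number argument: if -t is left off the command line, argparse sets args.do_test to None, so the range check has to cope with that. An alternative, sketched below, is to let argparse supply the default and check the range itself via its default and choices parameters. The parse_test_args() wrapper is a hypothetical name, there only so the parsing can be exercised without a real command line.

```python
import argparse

def parse_test_args(argv=None):
  """Parse the test number; argv=None means use the real command line."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--do_test', '-t', type=int, choices=[1, 2, 3], default=1,
                      help='test number (1-3), default is 1')
  return parser.parse_args(argv)

print(parse_test_args(['-t', '2']).do_test)  # 2
print(parse_test_args([]).do_test)           # 1
```

The behaviour is not quite the same as the code above: an out-of-range value such as -t 7 makes argparse print a usage error and exit, rather than quietly falling back to test 1.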

Sorting Plot Data Structure

There are similarities between the cases, but also basic differences, like plot title, possible legend values (countries or years), axis labels, x-axis values (age groups or years), etc. I propose to sort those out in the plot_population function and then send a suitable data structure to the function that actually generates and displays the plot. I expect this will make the function rather long. But, for now that’s how I am going to do it. Once I have something working, I will look at the need to use additional, more specific functions for each plot type.

When I started working on the above, I realized it was going to get rather crazy. So, I have decided to simply write a plotting function for each chart type/situation. Each function will get the appropriate data, then sort everything else out before plotting the chart.

I am also going to write a small function that sorts out the bar width for adjacent bar situations, the x-axis locations for the bars and the x-tick locations for charts that need that information. And, for the first two chart types, we no longer have the x-label information available in the data structure obtained from the CSV file functions. So, I will also add a small function to provide those labels. For now, I will do so in the chart module. But, it likely should be in the database/population module.

chart.chart.get_xticks(nbr_lbls=21, nbr_bars=1)

Let’s start with those first two. I’ll start with get_xticks(nbr_lbls=21, nbr_bars=1). It will return three values: the bar width for plotting (a float), bar locations (a list of lists of int/float, one list per bar in a group) for cases where we have adjacent bar groups being plotted (e.g. type 1 chart with multiple years) and the x-tick label locations (a list of floats) for the multiple bar chart situation. There are 21 age groups, hence the default value for the nbr_lbls parameter.

I got the method for determining the optimal bar width from this stackoverflow post: How to plot bar graphs with same X coordinates side by side.

import numpy as np  # needed at the top of the chart module, if not already imported

def get_xticks(nbr_lbls=21, nbr_bars=1):
  x_ticks = []
  b_width = 0.4

  lbl_x = [[] for _ in range(nbr_bars+1)] # store positions for each bar, indexed at 1 for first bar
  lbl_x[1] = list(np.arange(nbr_lbls))  # the first bar locations
  # Calculate optimal width
  # How to plot bar graphs with same X coordinates side by side
  # https://stackoverflow.com/questions/10369681/how-to-plot-bar-graphs-with-same-x-coordinates-side-by-side-dodged
  b_width = np.min(np.diff(lbl_x[1]))/(nbr_bars + 1.0)

  # sort x position for all other bars
  for i in range(2, nbr_bars+1):
    lbl_x[i] = [x + b_width for x in lbl_x[i-1]]

  x_ticks = [x + ((nbr_bars-1)/2*b_width) for x in lbl_x[1]]

  return b_width, lbl_x, x_ticks
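To see what that arithmetic produces, here is a small pure-Python restatement with a worked example. It is a sketch, not the module code: dodge_positions() is a hypothetical name, it assumes the labels sit at x = 0, 1, 2, ... (so the spacing between them is 1), and it returns a zero-indexed list of bar positions rather than the one-indexed lbl_x used above.

```python
def dodge_positions(nbr_lbls, nbr_bars):
  """Bar width, per-bar x positions and x-tick positions for grouped bars."""
  # share the unit spacing between the bars in a group plus one bar-width gap
  b_width = 1 / (nbr_bars + 1.0)
  first = list(range(nbr_lbls))  # x positions of the first bar in each group
  # each later bar is shifted one more bar width to the right
  bars = [[x + k * b_width for x in first] for k in range(nbr_bars)]
  # centre the tick under the middle of each group
  ticks = [x + (nbr_bars - 1) / 2 * b_width for x in first]
  return b_width, bars, ticks

b_width, bars, ticks = dodge_positions(nbr_lbls=3, nbr_bars=3)
print(b_width)  # 0.25
print(bars)     # [[0, 1, 2], [0.25, 1.25, 2.25], [0.5, 1.5, 2.5]]
print(ticks)    # [0.25, 1.25, 2.25]
```

With three bars per group the tick lands on the middle bar, and the leftover quarter of the spacing becomes the gap between groups.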

To help with my testing, I modified plot_population() to return the plot data rather than printing it. And, I added another argument for testing. It allows me to change the number of years when testing Type 1 or Type 3 plots. Then I reworked my testing code to facilitate the various requirements for each test case. Lots of duplicate code, sorry.

if __name__ == '__main__':
  import argparse
  # check for test number on command line
  parser = argparse.ArgumentParser()
  # long name preceded by --, short by single -, get it as an integer so can use to access test data array
  parser.add_argument('--do_test', '-t', type=int, help='test number (1-3)')
  parser.add_argument('--nbr_yrs', '-y', type=int, default=2, help='number of years to plot, default is 2')

  args = parser.parse_args()
  if args.do_test is not None and 1 <= args.do_test <= 3:
    do_tst = args.do_test
  else:
    do_tst = 1

  tst_deets = [
      [],
      [['Zimbabwe'], ['2005', 2], ['all']],
      [['Venezuela (Bolivarian Republic of)', 'Chile'], ['2010', 1], ['all']],
      [['Zimbabwe', 'Mozambique'], ['2010', 5], ['65-69']]
      ]

  if do_tst == 1:
    n_lbls = 21
    print(f"args.nbr_yrs {args.nbr_yrs}\n")
    if args.nbr_yrs != 2:
      tst_deets[do_tst][1][1] = args.nbr_yrs
    n_bars = tst_deets[do_tst][1][1]
    print(f"\nchart data: {tst_deets[do_tst]}\n")
    cnms, years, grps = plot_population(tst_deets[do_tst])
    print(cnms, years, grps)
    print('\n')

  elif do_tst == 2:
    n_lbls = 21
    n_bars = len(tst_deets[do_tst][0])
    print(f"\nchart data: {tst_deets[do_tst]}\n")
    cnms, years, grps = plot_population(tst_deets[do_tst])
    print(cnms, years, grps)
    print('\n')

  elif do_tst == 3:
    if args.nbr_yrs > 5:
      tst_deets[do_tst][1][1] = args.nbr_yrs
    n_lbls = tst_deets[do_tst][1][1]
    n_bars = len(tst_deets[do_tst][0])
    print(f"\nchart data: {tst_deets[do_tst]}\n")
    cnms, years, grps = plot_population(tst_deets[do_tst])
    print(cnms, years, grps)
    print('\n')

  b_wd, bar_x, ticks = get_xticks(nbr_lbls=n_lbls, nbr_bars=n_bars)
  print(f"bar width = {b_wd}\n\nbar positions: {bar_x}\n\nx-ticks: {ticks}\n")

Time to commit our changes. Then we can get to work on get_agrp_lbls(). We want it to return a list of the text for all 21 age groups. Simple, right?

X-Axis Labels for Charts

I am going to write a function to generate a list of age groups to use for the x-tick labels for plots of Type 1 and 2. The function will return a list of the correct text label for each of the age groups in the CSV file, in the proper order. For Type 3 plots the x-tick labels are the years being plotted, and we can get those from the population data structure: the data for each country is a dictionary keyed on the years being plotted.

Give it a try before you look at how I do it. Perfect place for a loop. You did recall that the last age group is a special case? We will print the labels in our test section to make sure we are getting something reasonable.

def get_agrp_lbls():
  ag_lbls = []
  for i in range(0, 100, 5):
    ag_end = str(i + 4)
    ag_lbls.append(str(i) + "-" + ag_end)
  ag_lbls.append("100+")
  return ag_lbls
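A quick sanity check on the result might look like the following. To keep the snippet self-contained it uses a list-comprehension variant under a hypothetical name, age_group_labels(), which should build the same list as the loop above.

```python
def age_group_labels():
  """Labels '0-4', '5-9', ..., '95-99', then the open-ended '100+' group."""
  labels = [f"{lo}-{lo + 4}" for lo in range(0, 100, 5)]
  labels.append("100+")  # the special case: no upper bound
  return labels

labels = age_group_labels()
print(len(labels))                        # 21
print(labels[0], labels[-2], labels[-1])  # 0-4 95-99 100+
```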

Now let’s turn our attention to generating the x-axis labels for the Type 3 chart. These are the years being plotted, the grouped bars are the appropriate population for each country being plotted. The years to be plotted are available in the plot data structure. The first element is a list of dictionaries keyed on the country names to be plotted. The value for each of these is another dictionary of the population values keyed on the years to be plotted. So the keys() of one of those nested dictionaries would be the list of x-tick labels we are after. Let’s add the code to do that to our test section for the Type 3 chart. Then print the list.

I had a bit of a glitch during my first attempt.

Traceback (most recent call last):
  File "r:/learn/py_play/population/chart/chart.py", line 170, in <module>
    labels = cnms[0][cnms[0].keys()[0]].keys()
TypeError: 'dict_keys' object is not subscriptable

The solution for my approach was to convert the value returned by keys() into a list.

    labels = list(cnms[0][list(cnms[0].keys())[0]].keys())
    print(labels)

Not sure I need the final list conversion. Will see when we actually start to plot the charts.
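For what it’s worth, here is a small sketch of the same lookup that avoids indexing keys() altogether. The data is a made-up placeholder that only mimics the shape of the Type 3 structure (country name, then year, then value); the numbers mean nothing.

```python
# placeholder data shaped like the Type 3 structure; the values are invented
cnms = [{
  'Zimbabwe': {'2010': 1.0, '2011': 2.0},
  'Mozambique': {'2010': 3.0, '2011': 4.0},
}]

# a dict_keys view can be iterated but not indexed
try:
  cnms[0].keys()[0]
except TypeError as err:
  print(err)  # 'dict_keys' object is not subscriptable

# take the first country's dictionary, then list its keys (the years)
first_country = next(iter(cnms[0].values()))
labels = list(first_country)
print(labels)  # ['2010', '2011']
```

Iterating a dictionary yields its keys in insertion order, so list(first_country) gives the years in the order they were added. Whether the final list() is needed depends on what consumes the labels; a list is the safe choice if anything downstream wants to index them.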

Bye For Now

Well, you know, I think that’s enough for this post. In the next one, accelerated schedule, I will move the test code into three new functions, one for each plot type. These functions will be called from plot_population() to generate the plots. We will add the plotting code next time as well.

Resources