Monday, September 2, 2024

doodle scatterplot python data visualization

 


This is a work-in-progress scatter plot tool in Python. It supports drawing multiple scatter plots in one figure. It should also support overlaying multiple data series on a single plot by passing a nested list of data objects, like:
sp = ScatterPlot(data=[[spd1, spd2]], nrows=1, ncols=1)
sp.doIt()
It has basic options such as setting the color and choosing whether or not to draw lines. Support for axis labels, titles, etc. can be added by modifying the ScatterPlotData class (or possibly by extending it through inheritance). There may be bugs, so please modify and use it at your own risk.
#ScatterPlot.py
#
#to test:
#python ScatterPlot.py

import logging
import numpy as np
import matplotlib
from matplotlib import pyplot

# module-level logger; ScatterPlot uses it to report invalid data entries
logger = logging.getLogger(__name__)
# configures the root logger at DEBUG level as a side effect of importing this module
# NOTE(review): this presumably also surfaces debug output from other libraries - confirm that is wanted
logging.basicConfig(level=logging.DEBUG)

class ScatterPlot(object):
    """
    draws ScatterPlotData objects on a grid of matplotlib subplots

    each entry of ``data`` fills one subplot, taken in row-major order.
    an entry is either a single ScatterPlotData object or a list of them;
    every object in a list is drawn on the same axes (overlay).

    supporting scatter plotting with different axis for each plot
    ???using same y axis on every plot
    ???using same x axis on every plot
    """
    def __init__(self, data=None, nrows=1, ncols=1):
        """store the plot entries and grid shape, then validate the entries

        data  -- list whose items are ScatterPlotData objects or lists of them
        nrows -- number of subplot rows
        ncols -- number of subplot columns

        raises TypeError when an entry is not a ScatterPlotData object
        """
        # None (not a list literal) as default so no list is shared between calls
        self._data = data or []
        self._nrows = nrows
        self._ncols = ncols

        self._fontSize = 12

        self._validateData()

    def setFontSize(self, fontSize=12):
        """set font size to use for all plots
        """
        self._fontSize = fontSize

    def _validateData(self):
        """raise TypeError unless every entry (or every member of a list
        entry) is a ScatterPlotData object
        """
        for entry in self._data:
            members = entry if isinstance(entry, list) else [entry]
            for member in members:
                if not isinstance(member, ScatterPlotData):
                    msg = "requires {ScatterPlotData} object"
                    logger.warning(msg)
                    raise TypeError(msg)

    @staticmethod
    def _drawOn(axes, plotObj):
        """run the plot command of one data object against the given axes
        """
        # NOTE(review): the command string is executed with eval(), so it must
        # only ever come from trusted code (ScatterPlotData builds it, but
        # setPyplotCmd accepts arbitrary text).
        eval("axes.{cmd}".format(cmd=plotObj.getPyplotCmd()))

    def doIt(self):
        """plots all plots using given rows and columns for figure

        returns False (and draws nothing) when there is no data, otherwise
        shows the figure and returns True
        """
        data = self._data
        if not data:
            return False

        # apply the font size before the axes exist so the text created for
        # this figure picks it up
        if self._fontSize:
            pyplot.rcParams.update({'font.size': self._fontSize})

        # squeeze=False always yields a 2-D array of axes, so a single plot,
        # a single row/column and a full grid are all handled the same way
        fig, axesGrid = pyplot.subplots(nrows=self._nrows, ncols=self._ncols,
                                        squeeze=False)
        allAxes = list(axesGrid.flat)  # row-major order

        if len(data) > len(allAxes):
            logger.warning("%d plot entries given but the grid only has %d axes; "
                           "the extra entries are not drawn",
                           len(data), len(allAxes))

        # zip stops at the shorter sequence: surplus axes stay empty instead
        # of raising IndexError, surplus entries are skipped
        for axes, entry in zip(allAxes, data):
            overlay = entry if isinstance(entry, list) else [entry]
            for plotObj in overlay:
                self._drawOn(axes, plotObj)

        fig.tight_layout()
        pyplot.show()

        return True

class ScatterPlotData(object):
    """
    holds one x/y data series plus the text of the axes command that draws it

    the command is a string such as ``plot([0, 1], [2, 4], 'o-', color='grey')``
    that the plotting code runs against a matplotlib axes object. the command
    is built once in the constructor; after changing the color or the line
    setting call initCmd() to rebuild it.

    supporting color
    ?? title
    ?? xlabel, ylabel
    ?? symbol type
    ?? xrange, yrange for axis
    """
    def __init__(self, x=None, y=None, color="grey"):
        """store the series and build the default command

        x, y  -- data values; they are written into the command text as-is
        color -- color name (string) or any other color value (e.g. RGB tuple)
        """
        self._x = x
        self._y = y
        self._color = color
        self._hasLine = False

        #default command: markers joined by a line, regardless of _hasLine
        #(_hasLine only takes effect once initCmd() is called)
        self._pyplotCmd = self._buildCmd(hasLine=True)

    def _buildCmd(self, hasLine):
        """return the command text for the current x, y and color

        hasLine -- True for a marker+line ``plot``, False for a ``scatter``
        """
        col = self._color
        # strings are quoted exactly as before; anything else (e.g. an RGB
        # tuple) is written via repr so it stays a valid Python literal
        colText = "'{}'".format(col) if isinstance(col, str) else repr(col)
        if hasLine:
            # 'o-' rather than 'ro-': the color comes from the color keyword,
            # so naming red in the format string only defined it twice
            template = "plot({x}, {y}, 'o-', color={col})"
        else:
            template = "scatter({x}, {y}, color={col})"
        return template.format(x=self._x, y=self._y, col=colText)

    def getX(self):
        return self._x

    def getY(self):
        return self._y

    def getColor(self):
        return self._color

    def setColor(self, color):
        """set the color; call initCmd() afterwards to rebuild the command"""
        self._color = color

    def setHasLine(self, val):
        """choose line plot (True) or scatter (False); call initCmd() afterwards"""
        self._hasLine = val

    def getPyplotCmd(self):
        return self._pyplotCmd

    def setPyplotCmd(self, cmd):
        """replace the command text with a custom one"""
        self._pyplotCmd = cmd

    def initCmd(self):
        """handles configuring plot command data
        """
        #working on supporting multiple kinds of plots
        self.setPyplotCmd(self._buildCmd(hasLine=bool(self._hasLine)))

    @classmethod
    def createByList(cls, x=None, y=None):
        """create from two python lists; raises TypeError for anything else"""
        if (not isinstance(x, list)) or (not isinstance(y, list)):
            raise TypeError("requres {list} data")
        return cls(x=x, y=y)

    @classmethod
    def createByNumpyList(cls, x=None, y=None):
        """create from two numpy arrays; raises TypeError for anything else"""
        if (not isinstance(x, np.ndarray)) or (not isinstance(y, np.ndarray)):
            raise TypeError("requres {np.ndarray} data")
        # tolist() converts to plain python numbers (and nested lists), so the
        # command text contains ordinary literals instead of numpy scalar reprs
        return cls(x=x.tolist(), y=y.tolist())
    
    
if __name__ == "__main__":
    print("yaay")

    # series 1: built from plain lists, blue, scatter only (no connecting line)
    spd1 = ScatterPlotData.createByList(x=[0, 1, 2], y=[2, 4, 8])
    spd1.setColor("blue")
    spd1.setHasLine(False)
    # rebuild the plot command from the settings above; without this call the
    # command created by the constructor (its defaults) is used
    spd1.initCmd()

    # series 2: built from numpy arrays, green, with a connecting line
    xValues = np.array([0, 1, 2])
    yValues = np.array([8, 12, 20])
    spd2 = ScatterPlotData.createByNumpyList(x=xValues, y=yValues)
    spd2.setColor("green")
    spd2.setHasLine(True)
    spd2.initCmd()

    # 4x4 grid: eight entries of series 2 followed by eight entries of series 1
    panels = [spd2] * 8 + [spd1] * 8
    sp = ScatterPlot(data=panels, nrows=4, ncols=4)
    sp.doIt()

    # other layouts to try (uncomment one at a time):
    #
    # single plot of one series
    #sp = ScatterPlot(data=[spd1], nrows=1, ncols=1)
    #sp.doIt()
    #print(spd1.getPyplotCmd())
    #print(spd2.getPyplotCmd())
    #
    #sp = ScatterPlot(data=[spd2], nrows=1, ncols=1)
    #sp.doIt()
    #
    # two plots side by side with a larger font
    #sp = ScatterPlot(data=[spd1, spd2], nrows=1, ncols=2)
    #sp.setFontSize(14)
    #sp.doIt()
    #
    # both series overlaid on a single plot
    #sp = ScatterPlot(data=[[spd1, spd2]], nrows=1, ncols=1)
    #sp.doIt()
 

Happy Sketching!

Inspired by,

Nils Gehlenborg