This is a work-in-progress scatter-plot tool in Python. It supports plotting multiple scatter plots, and it should also support overlaying multiple lines on a single plot when you pass a list of data objects, like this:
sp = ScatterPlot(data=[[spd1, spd2]], nrows=1, ncols=1)
sp.doIt()

It has basic options, such as setting the color or choosing whether to draw lines. Support for axis labels, titles, etc. can be added by modifying the ScatterPlotData class (or possibly by extending it through inheritance). (There may be bugs, so please modify/use it at your own risk.)
# ScatterPlot.py
#
# to test:
# python ScatterPlot.py
import logging

import numpy as np
import matplotlib
from matplotlib import pyplot

logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger whenever the module is
# imported; consider moving it under the __main__ guard below.
logging.basicConfig(level=logging.DEBUG)


class ScatterPlot(object):
    """Draw one or more scatter/line plots on a grid of subplots.

    Every item of ``data`` fills one subplot, in row-major order. An item is
    either a single ``ScatterPlotData`` or a list of them; the members of a
    list are overlaid on the same subplot.

    Each plot gets its own axes. Open questions (not implemented):
    ??? using the same y axis on every plot
    ??? using the same x axis on every plot
    """

    def __init__(self, data=None, nrows=1, ncols=1):
        """
        :param data: list of ``ScatterPlotData`` objects and/or lists of
            ``ScatterPlotData`` objects (a list is overlaid on one subplot)
        :param nrows: number of subplot rows in the figure
        :param ncols: number of subplot columns in the figure
        :raises TypeError: if an item is not a ``ScatterPlotData``
        """
        # None default instead of a shared mutable [] default.
        self._data = data or []
        self._nrows = nrows
        self._ncols = ncols
        self._fontSize = 12
        self._validateData()

    def setFontSize(self, fontSize=12):
        """Set the font size used for all plots (a falsy value leaves it unchanged)."""
        self._fontSize = fontSize

    def _validateData(self):
        """Raise TypeError unless every item (or nested list member) is ScatterPlotData."""
        for dat in self._data:
            items = dat if isinstance(dat, list) else [dat]
            for d in items:
                if not isinstance(d, ScatterPlotData):
                    logger.warning("requires {ScatterPlotData} object")
                    raise TypeError(
                        "requires ScatterPlotData object, got {}".format(type(d).__name__))

    def doIt(self):
        """Draw every plot on an ``nrows`` x ``ncols`` grid and show the figure.

        Items are assigned to subplots in row-major order. Grid cells left
        over when there are fewer items than cells are blanked; items beyond
        ``nrows * ncols`` are skipped with a warning.

        :return: False if there is nothing to plot, True otherwise
        """
        data = self._data
        if not data:
            return False

        numCells = self._nrows * self._ncols
        if len(data) > numCells:
            logger.warning("%d plot(s) do not fit in a %dx%d grid and are skipped",
                           len(data) - numCells, self._nrows, self._ncols)

        # rcParams are read when artists are created, so the font size must be
        # set before the figure/axes exist (and before tight_layout measures them).
        if self._fontSize:
            pyplot.rcParams.update({'font.size': self._fontSize})

        # squeeze=False always returns a 2-D array of Axes, so a single code
        # path handles 1x1, 1xN, Nx1 and NxM grids alike.
        fig, axes = pyplot.subplots(nrows=self._nrows, ncols=self._ncols, squeeze=False)
        for plotIndex, ax in enumerate(axes.flat):
            if plotIndex >= len(data):
                ax.set_axis_off()  # unused grid cell
                continue
            plotObj = data[plotIndex]
            # a list means several series overlaid on the same subplot
            plotObjs = plotObj if isinstance(plotObj, list) else [plotObj]
            for plotOb in plotObjs:
                plotOb.drawOn(ax)

        fig.tight_layout()
        pyplot.show()
        return True


class ScatterPlotData(object):
    """One x/y data series plus how to draw it.

    Supported: color, and line-with-markers vs. markers-only drawing.
    Ideas (not implemented):
    ?? title
    ?? xlabel, ylabel
    ?? symbol type
    ?? xrange, yrange for axis
    """

    # Axes-method call templates; also what getPyplotCmd() reports.
    # 'o-' carries no color letter, so it does not conflict with the color=
    # keyword (matplotlib warns when a color is given in both places).
    _CMD_TEMPLATES = {
        "plot": "plot({x}, {y}, 'o-', color='{col}')",
        "scatter": "scatter({x}, {y}, color='{col}')",
    }

    def __init__(self, x=None, y=None, color="grey"):
        """
        :param x: x coordinates
        :param y: y coordinates
        :param color: matplotlib color specification
        """
        self._x = x
        self._y = y
        self._color = color
        self._hasLine = False
        # Until initCmd() runs, the series is drawn as a line plot in the color
        # given here (the historical default), even though _hasLine is False.
        self._drawKind = "plot"
        self._drawColor = color
        self._customCmd = False  # set by setPyplotCmd(); switches drawing to eval
        self._pyplotCmd = self._buildCmd()

    def _buildCmd(self):
        """Return the call string describing the current draw kind and color."""
        return self._CMD_TEMPLATES[self._drawKind].format(
            x=self._x, y=self._y, col=self._drawColor)

    def getX(self):
        return self._x

    def getY(self):
        return self._y

    def getColor(self):
        return self._color

    def setColor(self, color):
        """Set the color; takes effect on the next initCmd() call."""
        self._color = color

    def setHasLine(self, val):
        """Choose line+markers (True) or markers only (False); applied by initCmd()."""
        self._hasLine = val

    def getPyplotCmd(self):
        """Return the Axes-method call (without the ``ax.`` prefix) as a string."""
        return self._pyplotCmd

    def setPyplotCmd(self, cmd):
        """Draw with ``cmd``, an Axes-method call string such as "bar([1], [2])".

        The string is executed with ``eval`` as ``ax.<cmd>``; never pass
        untrusted input.
        """
        self._pyplotCmd = cmd
        self._customCmd = True

    def initCmd(self):
        """Apply the current color and line setting to the draw command.

        Draws a line with markers when setHasLine(True) was called, otherwise
        markers only. Replaces any command set with setPyplotCmd().
        """
        self._drawKind = "plot" if self._hasLine else "scatter"
        self._drawColor = self._color
        self._pyplotCmd = self._buildCmd()
        self._customCmd = False

    def drawOn(self, ax):
        """Draw this series on the matplotlib Axes ``ax`` and return the result."""
        if self._customCmd:
            # Backward compatibility with setPyplotCmd(): evaluated as ax.<cmd>
            # with this module's globals (np, pyplot, ...) in scope.
            return eval("ax.{cmd}".format(cmd=self._pyplotCmd), globals(), {"ax": ax})
        # Built-in commands are called directly on the real data instead of
        # eval'ing its repr, which breaks for values such as nan or inf.
        if self._drawKind == "plot":
            return ax.plot(self._x, self._y, 'o-', color=self._drawColor)
        return ax.scatter(self._x, self._y, color=self._drawColor)

    @classmethod
    def createByList(cls, x=None, y=None):
        """Create from two Python lists.

        :raises TypeError: if x or y is not a list
        """
        if (not isinstance(x, list)) or (not isinstance(y, list)):
            raise TypeError("requres {list} data")
        return cls(x=x, y=y)

    @classmethod
    def createByNumpyList(cls, x=None, y=None):
        """Create from two numpy arrays, stored as plain Python lists.

        :raises TypeError: if x or y is not a numpy.ndarray
        """
        if (not isinstance(x, np.ndarray)) or (not isinstance(y, np.ndarray)):
            raise TypeError("requres {np.ndarray} data")
        # tolist() yields Python scalars; list() would keep numpy scalars,
        # whose repr (e.g. np.int64(0) on numpy >= 2) leaks into getPyplotCmd().
        return cls(x=x.tolist(), y=y.tolist())


if __name__ == "__main__":
    print("yaay")
    spd1 = ScatterPlotData.createByList(x=[0, 1, 2], y=[2, 4, 8])
    spd1.setColor("blue")
    spd1.setHasLine(False)
    spd1.initCmd()  # applies color/line settings; without it the defaults are used
    spd2 = ScatterPlotData.createByNumpyList(x=np.array([0, 1, 2]), y=np.array([8, 12, 20]))
    spd2.setColor("green")
    spd2.setHasLine(True)
    spd2.initCmd()

    sp = ScatterPlot(data=[spd2] * 8 + [spd1] * 8, nrows=4, ncols=4)
    sp.doIt()

    # Other examples:
    # side by side, larger font:
    #   sp = ScatterPlot(data=[spd1, spd2], nrows=1, ncols=2)
    #   sp.setFontSize(14)
    #   sp.doIt()
    # both series overlaid on one plot:
    #   sp = ScatterPlot(data=[[spd1, spd2]], nrows=1, ncols=1)
    #   sp.doIt()
Happy Sketching!
Inspired by,
Nils Gehlenborg