"""Interactive gradient-descent visualizer.

Draws level lines of a function f(x, y) typed by the user on a Tk canvas.
A left mouse click runs plain gradient descent from the clicked point and
draws its trajectory; middle/right clicks are reserved for the heavy-ball
and Nesterov variants, which are not implemented yet.
"""
from tkinter import *
import math
from R2Graph import *
import lparser

parser = lparser.Parser()
functionDefined = False  # True once the entry text compiled successfully


def func(x, y):
    """Return the objective value at (x, y).

    Uses the expression compiled into ``parser`` when one is defined,
    otherwise a built-in quadratic with its minimum at (2, 1).
    """
    if functionDefined:
        return parser.evaluate(x = x, y = y)
    # ``**`` is exponentiation; ``^`` would be bitwise XOR and raises
    # TypeError for float operands.
    return (x - 2)**2 + (y - 1)**2/4


def gradient(f, x = 0., y = 0.):
    """Approximate the gradient of ``f`` at a point by central differences.

    ``x`` may be a point-like object (R2Point, tuple or list) holding both
    coordinates, in which case ``y`` is ignored; otherwise ``x`` and ``y``
    are the scalar coordinates.

    Returns:
        R2Vector(df/dx, df/dy).
    """
    h = 0.01
    h2 = h*2.
    if isinstance(x, (R2Point, tuple, list)):
        xx = x[0]
        yy = x[1]
    else:
        xx = x
        yy = y
    dx = (f(xx + h, yy) - f(xx - h, yy))/h2
    dy = (f(xx, yy + h) - f(xx, yy - h))/h2
    return R2Vector(dx, dy)


EPS = 1e-5                # descent stops when the gradient norm is <= EPS
SCALEX = 40               # pixels per unit along x
SCALEY = SCALEX           # pixels per unit along y
STEPX = 5./SCALEX         # level-line sampling step: 5 pixels
STEPY = 5./SCALEY
MAX_GRADIENT_NORM = 1e3   # longer gradients are clipped to this norm


def main():
    """Build the window, wire up the event handlers and run the Tk loop."""
    root = Tk()
    root.title("Gradient Descent")
    root.geometry("1000x600")

    panel = Frame(root)
    drawArea = Canvas(root, bg="white")
    panel.pack(side=TOP, fill=X)

    drawButton = Button(panel, text="Draw")
    clearButton = Button(panel, text="Clear")
    functionLabel = Label(panel, text="f(x):", fg="DarkBlue")
    functionText = StringVar()  # Control variable connected with text entry
    functionEntry = Entry(
        panel, bg="white", textvariable=functionText,
        fg="DarkBlue", width=20
    )
    message = Label(panel, text="", fg="DarkBlue")
    stepLabel = Label(panel, text="Step:", fg="DarkBlue")
    stepScale = Scale(
        panel, from_=0.01, to=1., resolution=0.01,
        orient=HORIZONTAL, fg="DarkBlue"
    )
    stepScale.set(0.1)
    momentumLabel = Label(panel, text="Momentum:", fg="DarkGreen")
    momentumScale = Scale(
        panel, from_=0., to=0.99, resolution=0.01,
        orient=HORIZONTAL, fg="DarkGreen"
    )
    momentumScale.set(0.3)

    drawButton.pack(side=LEFT, padx=4, pady=4)
    clearButton.pack(side=LEFT, padx=4, pady=4)
    functionLabel.pack(side=LEFT, padx=4, pady=4)
    functionEntry.pack(side=LEFT, fill=X, expand=True, padx=4, pady=4)
    message.pack(side=LEFT, padx=4, pady=4)
    stepLabel.pack(side=LEFT, padx=4, pady=4)
    stepScale.pack(side=LEFT, padx=4, pady=4)
    momentumLabel.pack(side=LEFT, padx=4, pady=4)
    momentumScale.pack(side=LEFT, padx=4, pady=4)

    descentLine = []      # pixel coordinates of the last descent path
    descentLineIDs = []   # canvas item IDs drawn for that path
    # TODO: heavy-ball / Nesterov trajectories will need their own lists.

    drawArea.pack(side=TOP, fill=BOTH, expand=True, padx=4, pady=4)
    root.update()

    def toPixel(t):
        '''Map mathematical coordinates into pixel coordinates
        R2Point --> (x, y); the origin is the canvas centre.'''
        w = drawArea.winfo_width()
        h = drawArea.winfo_height()
        ox = w/2.
        oy = h/2.
        return (ox + t.x*SCALEX, oy - t.y*SCALEY)

    def toWorld(p):
        '''Map pixel coordinates into mathematical coordinates
        (x, y) --> R2Point; inverse of toPixel.'''
        w = drawArea.winfo_width()
        h = drawArea.winfo_height()
        ox = w/2.
        oy = h/2.
        x = float(p[0] - ox)/SCALEX
        y = float(oy - p[1])/SCALEY
        return R2Point(x, y)

    def xmin():
        w = drawArea.winfo_width()
        return (-w/2.)/SCALEX

    def xmax():
        w = drawArea.winfo_width()
        return (w/2.)/SCALEX

    def ymin():
        h = drawArea.winfo_height()
        return (-h/2.)/SCALEY

    def ymax():
        h = drawArea.winfo_height()
        return (h/2.)/SCALEY

    def eraseObjects(idList):
        '''Delete every canvas item in idList, then empty the list.'''
        for itemID in idList:
            drawArea.delete(itemID)
        idList.clear()

    def drawGrid():
        '''Draw the unit grid and the coordinate axes.'''
        x0 = xmin(); x1 = xmax()
        y0 = ymin(); y1 = ymax()
        # Grid
        for x in range(int(x0), int(x1) + 1):
            if x == 0:
                continue
            drawArea.create_line(
                toPixel(R2Point(x, y0)), toPixel(R2Point(x, y1)),
                fill="lightGray"
            )
        for y in range(int(y0), int(y1) + 1):
            if y == 0:
                continue
            drawArea.create_line(
                toPixel(R2Point(x0, y)), toPixel(R2Point(x1, y)),
                fill="lightGray"
            )
        # Coordinate axes
        drawArea.create_line(
            toPixel(R2Point(x0, 0.)), toPixel(R2Point(x1, 0.)),
            fill="black", width=2
        )
        drawArea.create_line(
            toPixel(R2Point(0., y0)), toPixel(R2Point(0., y1)),
            fill="black", width=2
        )

    def drawLevelLine(f, level = 0., color="gray"):
        '''Draw the contour f(x, y) == level over the visible area.

        The plane is sampled on a grid of STEPX x STEPY cells. Each cell is
        split along its z1-z2 diagonal into two triangles, and the contour
        is linearly interpolated along the triangle edges.'''
        def crosses(a, b):
            # True when level lies between a and b (inclusive at a).
            return (a <= level and level < b) or (a >= level and level > b)

        x0 = xmin(); x1 = xmax()
        y0 = ymin(); y1 = ymax()
        y = y0
        v1 = []
        while y < y1:
            # v0: samples on row y; v1: samples on row y + STEPY.
            if y > y0:
                v0 = v1.copy()
            else:
                v0 = []
            v1.clear()
            x = x0
            while x <= x1 + STEPX:
                if y <= y0:
                    v0.append(f(x, y))
                v1.append(f(x, y + STEPY))
                x += STEPX

            x = x0
            ix = 0
            while x < x1:
                # Cell corners and the edge-crossing slots in `points`:
                # z0--0--z2
                # |     /|
                # 1   2  4
                # | /    |
                # z1--3--z3
                z0 = v0[ix]
                z2 = v0[ix + 1]
                z1 = v1[ix]
                z3 = v1[ix + 1]
                points = [False]*5

                if crosses(z0, z2):
                    xx = x + STEPX*abs(level - z0)/abs(z2 - z0)
                    p02 = R2Point(xx, y)
                    points[0] = True
                if crosses(z0, z1):
                    yy = y + STEPY*abs(level - z0)/abs(z1 - z0)
                    p01 = R2Point(x, yy)
                    points[1] = True
                if crosses(z1, z2):
                    xx = x + STEPX*abs(level - z1)/abs(z2 - z1)
                    yy = (y + STEPY) - STEPY*abs(level - z1)/abs(z2 - z1)
                    p12 = R2Point(xx, yy)
                    points[2] = True
                if crosses(z1, z3):
                    xx = x + STEPX*abs(level - z1)/abs(z3 - z1)
                    p13 = R2Point(xx, y + STEPY)
                    points[3] = True
                if crosses(z2, z3):
                    yy = y + STEPY*abs(level - z2)/abs(z3 - z2)
                    p23 = R2Point(x + STEPX, yy)
                    points[4] = True

                # Upper triangle (z0, z1, z2):
                if points[0] and points[1]:
                    drawArea.create_line(toPixel(p02), toPixel(p01), fill=color)
                if points[0] and points[2]:
                    drawArea.create_line(toPixel(p02), toPixel(p12), fill=color)
                if points[1] and points[2]:
                    drawArea.create_line(toPixel(p01), toPixel(p12), fill=color)
                # Lower triangle (z1, z2, z3):
                if points[2] and points[4]:
                    drawArea.create_line(toPixel(p12), toPixel(p23), fill=color)
                if points[2] and points[3]:
                    drawArea.create_line(toPixel(p12), toPixel(p13), fill=color)
                if points[3] and points[4]:
                    drawArea.create_line(toPixel(p13), toPixel(p23), fill=color)

                x += STEPX
                ix += 1
            y += STEPY

    def plotFunc(f, color="blue"):
        '''Draw level lines of f at 1/4, 3/8, ... (factor 1.5) up to an
        estimate of its maximum over the view.

        The maximum is estimated by sampling f along x = 0, +-1 and
        y = 0, +-1 at unit steps, capped at 1000.
        NOTE(review): `color` is currently unused; level lines use
        drawLevelLine's default colour.'''
        vMax = -1e+30
        x0 = xmin(); x1 = xmax()
        y0 = ymin(); y1 = ymax()

        x = x0
        while x <= x1:
            for v in (f(x, 0.), f(x, 1.), f(x, -1.)):
                if v > vMax:
                    vMax = v
            x += 1.
        y = y0
        while y <= y1:
            for v in (f(0., y), f(1., y), f(-1., y)):
                if v > vMax:
                    vMax = v
            y += 1.
        if vMax > 1000.:
            vMax = 1000.

        magStep = 1.5
        v = 1./4.
        while v <= vMax:
            drawLevelLine(f, v)
            v *= magStep

    def onDraw():
        '''Compile the entry text and redraw grid plus level lines.'''
        global functionDefined
        text = functionEntry.get()
        parser.setParseLine(text)
        (success, errorText) = parser.compile()
        message.configure(text = errorText)
        onClear()
        drawArea.delete("all")
        drawGrid()
        functionDefined = success
        if functionDefined:
            plotFunc(func, color="blue")

    def gradientDescent(x0):
        '''Run up to 100 steps of x -= alpha*grad f(x) from x0 and draw
        the path; alpha is read from the "Step" slider.'''
        print("gradientDescent: x0 =", x0)
        maxSteps = 100
        n = 0
        alpha = stepScale.get()
        descentLine.clear()
        x = x0.copy()
        descentLine.append(toPixel(x))
        while n < maxSteps:
            g = gradient(func, x)
            norm_g = g.norm()
            if norm_g <= EPS:
                break
            if norm_g > MAX_GRADIENT_NORM:
                # Clip very long gradients to keep the step bounded.
                g *= MAX_GRADIENT_NORM/norm_g
            x -= g*alpha
            descentLine.append(toPixel(x))
            n += 1
        print("min =", x, "iterations:", n)

        eraseObjects(descentLineIDs)
        # Tk's create_line needs at least two points; a start point that
        # is already stationary yields only one.
        if len(descentLine) >= 2:
            lineID = drawArea.create_line(descentLine, fill="red", width=2)
            descentLineIDs.append(lineID)
        textPoint = toPixel(R2Point(xmin() + 0.2, ymax() - 1))
        textID = drawArea.create_text(
            textPoint,
            text = "Gradient Descent, iterations: " + str(n),
            anchor="nw", font=("Times", 18), fill="red"
        )
        descentLineIDs.append(textID)

    def heavyBall(x0):
        print("heavyBall: x0 =", x0)
        print("Not implemented yet...")

    def nesterovAcceleration(x0):
        print("nesterovAcceleration: x0 =", x0)
        print("Not implemented yet...")

    def onClear():
        '''Remove the drawn descent trajectories.'''
        eraseObjects(descentLineIDs)

    drawButton.configure(command = onDraw)
    clearButton.configure(command = onClear)

    def redraw():
        drawArea.delete("all")
        drawGrid()
        if functionDefined:
            plotFunc(func, color="blue")

    def onConfigure(e):
        redraw()

    drawArea.bind("<Configure>", onConfigure)
    functionEntry.bind("<Return>", lambda e: onDraw())

    def onMouseWheel(e):
        '''Process MouseWheel event: scale a picture'''
        global SCALEX, SCALEY, STEPX, STEPY
        scaleFactor = 1.2
        # Respond to Linux (Button-4/5) or Windows/MacOS (delta) wheel event
        if e.num == 4 or e.delta > 0:
            SCALEX *= scaleFactor
        elif e.num == 5 or e.delta < 0:
            if SCALEX <= 8:
                return
            SCALEX /= scaleFactor
        else:
            return
        SCALEY = SCALEX
        # Keep the level-line sampling step at 5 pixels after zooming.
        STEPX = 5./SCALEX
        STEPY = 5./SCALEY
        redraw()

    def onMouseRelease(e):
        '''Start an optimizer from the clicked point, chosen by button.'''
        p = (e.x, e.y)
        mouseButton = e.num
        t = toWorld(p)
        if mouseButton == 1:    # Left button
            gradientDescent(t)
        elif mouseButton == 2:  # Middle button
            heavyBall(t)
        elif mouseButton == 3:  # Right button
            nesterovAcceleration(t)

    # Zoom with the mouse wheel (currently disabled):
    # drawArea.bind("<MouseWheel>", onMouseWheel)  # for Windows/MacOS
    # drawArea.bind("<Button-4>", onMouseWheel)    # for Linux (scroll up)
    # drawArea.bind("<Button-5>", onMouseWheel)    # for Linux (scroll down)

    drawArea.bind("<ButtonRelease-1>", onMouseRelease)
    drawArea.bind("<ButtonRelease-2>", onMouseRelease)
    drawArea.bind("<ButtonRelease-3>", onMouseRelease)

    root.update()
    drawGrid()
    functionText.set("((x-2)+(y-1))^2/4 + ((x-2)-(y-1))^2")
    onDraw()
    root.mainloop()


if __name__ == "__main__":
    main()