{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# GLVQ Benchmark\nThis example shows the differences between the four GLVQ implementations and LMNN.\nThe Image Segmentation dataset is used for both training and testing. Each plot shows the projection\nand classification produced by each implementation. Because GLVQ cannot project the data on its own,\nPCA is used for its plot.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from __future__ import with_statement\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from metric_learn import LMNN\n",
        "from sklearn.decomposition import PCA\n",
        "\n",
        "from sklearn_lvq import GlvqModel, GrlvqModel, LgmlvqModel, GmlvqModel\n",
        "from sklearn_lvq.utils import _to_tango_colors, _tango_color\n",
        "\n",
        "# Comma-separated text file: the first field of each row is the class\n",
        "# label, the remaining fields are parsed as float features.\n",
        "DATA_PATH = 'segmentation.data'\n",
        "\n",
        "\n",
        "def plot(data, target, target_p, prototype, prototype_label, p):\n",
        "    \"\"\"Draw one benchmark panel on the axes ``p``.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    data : array\n",
        "        Projected samples; only columns 0 and 1 are drawn.\n",
        "    target : sequence\n",
        "        True label per sample (colors the large translucent markers).\n",
        "    target_p : sequence\n",
        "        Predicted label per sample (colors the small dot drawn on top).\n",
        "    prototype : array\n",
        "        Projected prototypes; only columns 0 and 1 are drawn.\n",
        "    prototype_label : sequence or scalar\n",
        "        One label per prototype, or a single scalar color index when\n",
        "        only one prototype is drawn (as in the LGMLVQ panels).\n",
        "    p : matplotlib axes\n",
        "        Axes to draw into.\n",
        "    \"\"\"\n",
        "    p.scatter(data[:, 0], data[:, 1], c=_to_tango_colors(target, 0), alpha=0.5)\n",
        "    p.scatter(data[:, 0], data[:, 1], c=_to_tango_colors(target_p, 0),\n",
        "              marker='.')\n",
        "    p.scatter(prototype[:, 0], prototype[:, 1],\n",
        "              c=_tango_color('aluminium', 5), marker='D')\n",
        "    # Pick the color helper from the shape of the label instead of relying\n",
        "    # on a bare ``except:``, which also swallowed KeyboardInterrupt and any\n",
        "    # genuine plotting error.\n",
        "    # NOTE(review): assumes the old fallback only existed for scalar labels.\n",
        "    if np.ndim(prototype_label) == 0:\n",
        "        prototype_color = _tango_color(prototype_label)\n",
        "    else:\n",
        "        prototype_color = _to_tango_colors(prototype_label, 0)\n",
        "    p.scatter(prototype[:, 0], prototype[:, 1], s=60,\n",
        "              c=prototype_color, marker='.')\n",
        "    p.axis('equal')\n",
        "\n",
        "\n",
        "# Load the data set into a float feature matrix ``x`` and a label vector ``y``.\n",
        "y = []\n",
        "x = []\n",
        "with open(DATA_PATH) as f:\n",
        "    for line in f:\n",
        "        line = line.strip()\n",
        "        if not line:\n",
        "            # A blank line (e.g. at the end of the file) would otherwise add\n",
        "            # an empty row and make the float conversion below fail.\n",
        "            continue\n",
        "        v = line.split(',')\n",
        "        y.append(v[0])\n",
        "        x.append(v[1:])\n",
        "x = np.asarray(x, dtype='float64')\n",
        "y = np.asarray(y)\n",
        "\n",
        "# LMNN: the first two columns of the transformed data are plotted directly.\n",
        "lmnn = LMNN(k=5, learn_rate=1e-6)\n",
        "lmnn.fit(x, y)\n",
        "x_t = lmnn.transform(x)\n",
        "\n",
        "p1 = plt.subplot(231)\n",
        "p1.scatter(x_t[:, 0], x_t[:, 1], c=_to_tango_colors(y, 0))\n",
        "p1.axis('equal')\n",
        "p1.set_title('LMNN')\n",
        "\n",
        "# GLVQ: the data is projected with PCA; the prototypes are passed as-is.\n",
        "glvq = GlvqModel()\n",
        "glvq.fit(x, y)\n",
        "p2 = plt.subplot(232)\n",
        "p2.set_title('GLVQ')\n",
        "plot(PCA().fit_transform(x), y, glvq.predict(x), glvq.w_, glvq.c_w_, p2)\n",
        "\n",
        "# GRLVQ: data and prototypes go through the model's own project().\n",
        "grlvq = GrlvqModel()\n",
        "grlvq.fit(x, y)\n",
        "p3 = plt.subplot(233)\n",
        "p3.set_title('GRLVQ')\n",
        "plot(grlvq.project(x, 2),\n",
        "     y, grlvq.predict(x), grlvq.project(grlvq.w_, 2),\n",
        "     grlvq.c_w_, p3)\n",
        "\n",
        "# GMLVQ: same layout as GRLVQ.\n",
        "gmlvq = GmlvqModel()\n",
        "gmlvq.fit(x, y)\n",
        "p4 = plt.subplot(234)\n",
        "p4.set_title('GMLVQ')\n",
        "plot(gmlvq.project(x, 2),\n",
        "     y, gmlvq.predict(x), gmlvq.project(gmlvq.w_, 2),\n",
        "     gmlvq.c_w_, p4)\n",
        "\n",
        "# LGMLVQ: two panels, one for the prototype at index 1 and one for the\n",
        "# prototype at index 6. Each panel draws a single prototype, so its label\n",
        "# is passed to plot() as a scalar index into ``elem_set``.\n",
        "# NOTE(review): the extra project() arguments presumably select the\n",
        "# prototype-local projection -- confirm against the sklearn_lvq docs.\n",
        "lgmlvq = LgmlvqModel()\n",
        "lgmlvq.fit(x, y)\n",
        "p5 = plt.subplot(235)\n",
        "elem_set = list(set(lgmlvq.c_w_))\n",
        "p5.set_title('LGMLVQ 1')\n",
        "plot(lgmlvq.project(x, 1, 2, True),\n",
        "     y, lgmlvq.predict(x), lgmlvq.project(np.array([lgmlvq.w_[1]]), 1, 2),\n",
        "     elem_set.index(lgmlvq.c_w_[1]), p5)\n",
        "p6 = plt.subplot(236)\n",
        "p6.set_title('LGMLVQ 2')\n",
        "plot(lgmlvq.project(x, 6, 2, True),\n",
        "     y, lgmlvq.predict(x), lgmlvq.project(np.array([lgmlvq.w_[6]]), 6, 2),\n",
        "     elem_set.index(lgmlvq.c_w_[6]), p6)\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}