{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cd3b4cc2-0bb3-4fba-9c92-48e40f5419c4",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Notebook parameters (cell tagged \"parameters\"); max_records is a string here and is cast to int in the next cell.\n",
    "filepattern = \"tabby/dataset/data.jsonl\"\n",
    "api = \"http://localhost:8080\"\n",
    "max_records = \"3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f12319d9",
   "metadata": {
    "tags": [
     "remove"
    ]
   },
   "outputs": [],
   "source": [
    "max_records = int(max_records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "172d7105-ecac-4019-bbe1-dcd70ed6af60",
   "metadata": {
    "tags": [
     "remove"
    ]
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import editdistance\n",
    "import pandas as pd\n",
    "\n",
    "import processing\n",
    "from tabby_client import Client\n",
    "from tabby_client.api.v1 import completion, health\n",
    "from tabby_client.models import Choice, CompletionRequest, Segments\n",
    "\n",
    "\n",
    "def valid_item(item: processing.Item) -> bool:\n",
    "    \"\"\"Keep only items whose expected completion (`body`) has at most 10 lines.\"\"\"\n",
    "    return len(item.body.splitlines()) <= 10\n",
    "\n",
    "\n",
    "def scorer(label: str, prediction: str) -> float:\n",
    "    \"\"\"Score `prediction` against `label` as 1 - edit_distance / len(label), clamped at 0.\n",
    "\n",
    "    An empty label scores 1.0 for an empty prediction and 0.0 otherwise,\n",
    "    because the ratio would divide by zero.\n",
    "    \"\"\"\n",
    "    if not label:\n",
    "        return 1.0 if not prediction else 0.0\n",
    "    distance = editdistance.eval(label, prediction)\n",
    "    return max(0.0, 1.0 - distance / len(label))\n",
    "\n",
    "\n",
    "def first_line(text: str) -> str:\n",
    "    \"\"\"Return the first line of `text`, or an empty string if it has none.\"\"\"\n",
    "    lines = text.splitlines()\n",
    "    return lines[0] if lines else \"\"\n",
    "\n",
    "\n",
    "def run_eval():\n",
    "    \"\"\"Request a completion for each sampled item and yield one score record per item.\n",
    "\n",
    "    Reads the notebook parameters `api`, `filepattern` and `max_records`.\n",
    "    Each record holds the prompt, prediction, label, a whole-block score and a\n",
    "    first-line score. Yields nothing if the server health check raises.\n",
    "    \"\"\"\n",
    "    client = Client(base_url=api, timeout=50)\n",
    "    try:\n",
    "        health.sync(client=client)\n",
    "    except Exception as e:  # not a bare except, so KeyboardInterrupt still works\n",
    "        print(f\"Tabby Server is not ready, please check if '{api}' is correct.\")\n",
    "        print(f\"({type(e).__name__}: {e})\")\n",
    "        return\n",
    "\n",
    "    items = [x for x in processing.items_from_filepattern(filepattern) if valid_item(x)]\n",
    "    if len(items) > max_records:\n",
    "        # Private RNG: same deterministic sample as seeding the global one, without side effects.\n",
    "        items = random.Random(0xbadbeef).sample(items, max_records)\n",
    "\n",
    "    for item in items:\n",
    "        request = CompletionRequest(\n",
    "            language=item.language, segments=Segments(prefix=item.prefix)\n",
    "        )\n",
    "        resp = completion.sync(client=client, json_body=request)\n",
    "\n",
    "        label = item.body\n",
    "        # Guard against a missing response or an empty `choices` list.\n",
    "        prediction = resp.choices[0].text if resp is not None and resp.choices else \"\"\n",
    "\n",
    "        yield dict(\n",
    "            prompt=item.prefix,\n",
    "            prediction=prediction,\n",
    "            label=label,\n",
    "            block_score=scorer(label, prediction),\n",
    "            # Scored for every item, so no record reuses a previous item's line_score.\n",
    "            line_score=scorer(first_line(label), first_line(prediction)),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "76c08e41-42fc-486a-96b3-5cf647635e90",
   "metadata": {
    "tags": [
     "remove"
    ]
   },
   "outputs": [],
   "source": [
    "df = pd.DataFrame(list(run_eval()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "038f9c95-edf4-463a-a600-d1945b17c235",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([], dtype=object)"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABL4AAAHDCAYAAAAqZtO0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA4iUlEQVR4nO3de5xVdbk/8GdGh0EOjEgEyJ0U74i3UDAPpOCIaHJSLOwEImoXSYnUX1YK6Cks7yc1NVOsDqFiYKcQHREkc9SDikcsPWKopQyK5AyXHEdm/f4oRseZgdnArHEv3+/Xa/7Ya3/X3s96WK/N8/rM2msKkiRJAgAAAAAyprC1CwAAAACAliD4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+IKPqZkzZ0ZBQUG8/PLLERExbNiwGDZsWKvWBABA65o2bVoUFBTEmjVrtriub9++cfrpp7dIDS+//HIUFBTElVde2SKvD3y8CL4AAAAAyKSdW7sA4KPhgQceaO0SAAAAYIdyxRcQERFt2rSJNm3atHYZre69996Ld999t7XLAAAgA2pra+Odd95p7TLgY03wBUREw3t8LV68OAoKCuKuu+6K73//+9GzZ89o27ZtHHPMMbFixYoG+z/++ONx3HHHxa677hrt2rWLoUOHxh/+8Iec65g9e3Yceuih0aFDhygpKYkBAwbEddddV2/N22+/Hd/85jejb9++UVxcHD179oxx48bVuxfFG2+8ERMnToyuXbtG27ZtY+DAgXHHHXfUe50P3j/i2muvjT322COKi4vjj3/8Y0REPP/883HKKadEp06dom3btnHYYYfFb37zm5yPCQAg36xZsyZOPfXUKCkpiU984hNx3nnnbTXA+fOf/xxjxoyJTp06Rbt27eKII46I3/3udw3WvfPOOzFt2rTYa6+9om3btrH77rvH5z//+XjppZeafO0kSeLss8+ONm3axK9//etmH0dZWVl85jOfiY4dO0b79u1j7733ju985zs517Nhw4b41re+Fb169Yri4uLYe++948orr4wkSeq9VkFBQUyaNCn+67/+K/bff/8oLi6OBQsWRETEa6+9FmeccUZ07do1iouLY//994/bbrut2ccCbBtfdQS26PLLL4/CwsI4//zzo7KyMn70ox/Fl770pXj88cfr1jz00EMxcuTIOPTQQ2Pq1KlRWFgYt99+exx99NHx+9//PgYNGtSs9yorK4uxY8fGMcccEz/84Q8jIuJPf/pT/OEPf4jzzjsvIiLWr18fRx11VPzpT3+KM844Iw455JBYs2ZN/OY3v4m//vWv0blz5/j73/8ew4YNixUrVsSkSZOiX79+cffdd8fpp58eb7/9dt1rbXb77bfHO++8E2effXYUFxdHp06d4rnnnosjjzwyevToEd/+9rfjX/7lX+Kuu+6K0aNHxz333BP/9m//toM6DADw0XPqqadG3759Y8aMGfHYY4/Ff/7nf8bf/va3+PnPf97o+tWrV8eQIUNi48aNce6558YnPvGJuOOOO+Jzn/tczJkzp2522rRpU5xwwgmxcOHC+OIXvxjnnXderFu3LsrKymL58uWxxx57NHjtTZs2xRlnnBF33nlnzJ07N0aNGtWsY3juuefihBNOiAMPPDAuvfTSKC4ujhUrVtT75Wxz6kmSJD73uc/FokWLYuLEiXHQQQfF/fffHxdccEG89tprcc0119R734ceeijuuuuumDRpUnTu3Dn69u0bq1evjiOOOKIuGPvkJz8Z9913X0ycODGqqqpi8uTJzfyXAXKWAB9Lt99+exIRycqVK5MkSZKhQ4cmQ4cOrXt+0aJFSUQk++67b1JdXV23/brrrksiInn22WeTJEmS2trapH///klpaWlSW1tbt27jxo1Jv379khEjRjS7pvPOOy8pKSlJ3nvvvSbXXHLJJUlEJL/+9a8bPLf
5/a+99tokIpJf/vKXdc+9++67yeDBg5P27dsnVVVVSZIkycqVK5OISEpKSpI33nij3msdc8wxyYABA5J33nmn3usPGTIk6d+/f7OPCQAgn0ydOjWJiORzn/tcve1f//rXk4hInnnmmSRJkqRPnz7J+PHj656fPHlyEhHJ73//+7pt69atS/r165f07ds32bRpU5IkSXLbbbclEZFcffXVDd578yy3eUa74oorkpqamuQLX/hCsssuuyT3339/TsdyzTXXJBGRvPnmm02uaU498+bNSyIi+Y//+I96z59yyilJQUFBsmLFirptEZEUFhYmzz33XL21EydOTHbfffdkzZo19bZ/8YtfTHbddddk48aNOR0b0Hy+6ghs0YQJE+rd++uoo46KiH9cyh4RsWzZsnjxxRfjtNNOi7feeivWrFkTa9asiQ0bNsQxxxwTS5Ysidra2ma9V8eOHWPDhg1RVlbW5Jp77rknBg4c2OgVVwUFBRERMX/+/OjWrVuMHTu27rmioqI499xzY/369fHwww/X2+/kk0+OT37yk3WP165dGw899FCceuqpsW7durpjeuutt6K0tDRefPHFeO2115p1TAAA+eicc86p9/gb3/hGRPxjzmrM/PnzY9CgQfGZz3ymblv79u3j7LPPjpdffrnuVhL33HNPdO7cue71PmjzLLfZu+++G2PGjInf/va3MX/+/Dj22GNzOoaOHTtGRMS9997b5DzanHrmz58fO+20U5x77rn1nv/Wt74VSZLEfffdV2/70KFDY7/99qt7nCRJ3HPPPXHiiSdGkiR1s+WaNWuitLQ0Kisr46mnnsrp2IDmE3wBW9S7d+96j3fbbbeIiPjb3/4WEREvvvhiRESMHz8+PvnJT9b7ufXWW6O6ujoqKyub9V5f//rXY6+99oqRI0dGz54944wzzqi7J8JmL730UhxwwAFbfJ1XXnkl+vfvH4WF9T/i9t1337rnP6hfv371Hq9YsSKSJImLL764wTFNnTo1Iv5xDzEAgKzq379/vcd77LFHFBYWxssvv9zo+ldeeSX23nvvBts/PH+99NJLsffee8fOO2/9rjszZsyIefPmxZw5c+rdi7a5vvCFL8SRRx4ZZ555ZnTt2jW++MUvxl133VUvBGtOPa+88kp07949OnTosMVj2+zDs+Wbb74Zb7/9dtxyyy0NZssJEyZEhNkSWpJ7fAFbtNNOOzW6PfnnjTw3Dw5XXHFFHHTQQY2ubd++fbPeq0uXLrFs2bK4//7747777ov77rsvbr/99hg3blyDG9PvSLvssku9x5uP6fzzz4/S0tJG99lzzz1brB4AgI+aD1+NlYbS0tJYsGBB/OhHP4phw4ZF27Ztc9p/l112iSVLlsSiRYvid7/7XSxYsCDuvPPOOProo+OBBx5ocs7dXk3Nlv/+7/8e48ePb3SfAw88sEVqAQRfwHbafAPSkpKSGD58+Ha/Xps2beLEE0+ME088MWpra+PrX/963HzzzXHxxRfHnnvuGXvssUcsX758i6/Rp0+f+N///d+ora2td9XX888/X/f8lnzqU5+KiH98PXJHHBMAQL558cUX6125tGLFiqitrY2+ffs2ur5Pnz7xwgsvNNj+4flrjz32iMcffzxqamqiqKhoizUcccQR8dWvfjVOOOGEGDNmTMydO7dZV4p9UGFhYRxzzDFxzDHHxNVXXx0/+MEP4rvf/W4sWrQohg8f3qx6+vTpEw8++GCsW7eu3lVfzZ0tP/nJT0aHDh1i06ZNZktoBb7qCGyXQw89NPbYY4+48sorY/369Q2ef/PNN5v9Wm+99Va9x4WFhXW//aquro6If9yP65lnnom5c+c22H/zVWjHH398VFRUxJ133ln33HvvvRc//vGPo3379jF06NAt1tGlS5cYNmxY3HzzzbFq1artOiYAgHx0ww031Hv84x//OCIiRo4c2ej6448/Pp544okoLy+v27Zhw4a45ZZbom/fvnX3vDr
55JNjzZo1cf311zd4jc2z3AcNHz48Zs+eHQsWLIgvf/nLzb53bMQ/7tv6YZu/ofDB2XJr9Rx//PGxadOmBmuuueaaKCgoaLInm+20005x8sknxz333NPoL3DNltCyXPEFbJfCwsK49dZbY+TIkbH//vvHhAkTokePHvHaa6/FokWLoqSkJP77v/+7Wa915plnxtq1a+Poo4+Onj17xiuvvBI//vGP46CDDqq7h8IFF1wQc+bMiTFjxsQZZ5wRhx56aKxduzZ+85vfxE033RQDBw6Ms88+O26++eY4/fTT48knn4y+ffvGnDlz4g9/+ENce+21De7P0JgbbrghPvOZz8SAAQPirLPOik996lOxevXqKC8vj7/+9a/xzDPPbFffAAA+ylauXBmf+9zn4rjjjovy8vL45S9/GaeddloMHDiw0fXf/va341e/+lWMHDkyzj333OjUqVPccccdsXLlyrjnnnvqrsIfN25c/PznP48pU6bEE088EUcddVRs2LAhHnzwwfj6178eJ510UoPXHj16dN3tL0pKSuLmm29u1jFceumlsWTJkhg1alT06dMn3njjjbjxxhujZ8+edTfhb049J554Ynz2s5+N7373u/Hyyy/HwIED44EHHoh77703Jk+eXPcNiC25/PLLY9GiRXH44YfHWWedFfvtt1+sXbs2nnrqqXjwwQcbDemAHaT1/qAk0Jpuv/32JCKSlStXJkmSJEOHDk2GDh1a9/yiRYuSiEjuvvvuevtt/vPSt99+e73tTz/9dPL5z38++cQnPpEUFxcnffr0SU499dRk4cKFza5pzpw5ybHHHpt06dIladOmTdK7d+/kK1/5SrJq1ap66956661k0qRJSY8ePZI2bdokPXv2TMaPH1/vz0OvXr06mTBhQtK5c+ekTZs2yYABAxrU/ME/ld2Yl156KRk3blzSrVu3pKioKOnRo0dywgknJHPmzGn2MQEA5JOpU6cmEZH88Y9/TE455ZSkQ4cOyW677ZZMmjQp+fvf/163rk+fPsn48ePr7fvSSy8lp5xyStKxY8ekbdu2yaBBg5Lf/va3Dd5j48aNyXe/+92kX79+SVFRUdKtW7fklFNOSV566aUkSZqe0W688cYkIpLzzz+/WceycOHC5KSTTkq6d++etGnTJunevXsyduzY5P/+7/9yqidJkmTdunXJN7/5zaR79+5JUVFR0r9//+SKK65Iamtr671WRCTnnHNOo/WsXr06Oeecc5JevXrVvc8xxxyT3HLLLc06HmDbFCRJI9eTAgAAAECec48vAAAAADLJPb6AFrdp06at3rSzffv20b59+5QqAgAgn1VUVGzx+V122SV23XXXlKoBPsp81RFocS+//HK9P4fdmKlTp8a0adPSKQgAgLxWUFCwxefHjx8fM2fOTKcY4CPNFV9Ai+vWrVuUlZVtcc2nPvWplKoBACDfbW227N69e0qVAB91rvgCAAAAIJPc3B4AAACATMqLrzrW1tbG66+/Hh06dNjqd7kBgPyRJEmsW7cuunfvHoWFfh/HjmWGBIBsymWGzIvg6/XXX49evXq1dhkAQAv5y1/+Ej179mztMsgYMyQAZFtzZsi8CL46dOgQEf84oJKSklau5qOhpqYmHnjggTj22GOjqKiotcvJPP1Oj16nS7/Tpd8NVVVVRa9ever+r4cdyQzZkM+hdOl3uvQ7PXqdLv1uKJcZMi+Cr82XppeUlBha/qmmpibatWsXJSUlTvwU6Hd69Dpd+p0u/W6ar6HREsyQDfkcSpd+p0u/06PX6dLvpjVnhnQzDQAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyKScgq+f/OQnceCBB0ZJSUmUlJTE4MGD47777tviPnfffXfss88+0bZt2xgwYEDMnz9/uwoGAAAAgObIKfjq2bN
nXH755fHkk0/G0qVL4+ijj46TTjopnnvuuUbXP/roozF27NiYOHFiPP300zF69OgYPXp0LF++fIcUDwAAAABNySn4OvHEE+P444+P/v37x1577RXf//73o3379vHYY481uv66666L4447Li644ILYd99947LLLotDDjkkrr/++h1SPAAAAAA0ZZvv8bVp06aYPXt2bNiwIQYPHtzomvLy8hg+fHi9baWlpVFeXr6tbwsAAAAAzbJzrjs8++yzMXjw4HjnnXeiffv2MXfu3Nhvv/0aXVtRURFdu3att61r165RUVGxxfeorq6O6urqusdVVVUREVFTUxM1NTW5lpxJm/ugH+nQ7/Todbr0O1363ZBesCOZIbfO51C69Dtd+p0evU6XfjeUSy9yDr723nvvWLZsWVRWVsacOXNi/Pjx8fDDDzcZfm2LGTNmxPTp0xtsf+CBB6Jdu3Y77H2yoKysrLVL+FjR7/Todbr0O136/b6NGze2dglkiBmy+XwOpUu/06Xf6dHrdOn3+3KZIQuSJEm2582GDx8ee+yxR9x8880Nnuvdu3dMmTIlJk+eXLdt6tSpMW/evHjmmWeafM3GflvXq1evWLNmTZSUlGxPuZlRU1MTZWVlMWLEiCgqKmrtcjJPv9OzudcXLy2M6tqC1i4nJ8unlbZ2CTlzbqdLvxuqqqqKzp07R2Vlpf/j2W5myK3zOZQu/U5Xvs6RZki2Rr8bymWGzPmKrw+rra2tN2B80ODBg2PhwoX1gq+ysrIm7wm2WXFxcRQXFzfYXlRU5B/5Q/QkXfqdnuragqjelD8DS0Tk9bnh3E6Xfr9PH9iRzJDNpyfp0u905dscmc/nhnM7Xfr9vlz6kFPwddFFF8XIkSOjd+/esW7dupg1a1YsXrw47r///oiIGDduXPTo0SNmzJgRERHnnXdeDB06NK666qoYNWpUzJ49O5YuXRq33HJLLm8LAAAAADnLKfh64403Yty4cbFq1arYdddd48ADD4z7778/RowYERERr776ahQWvv+HIocMGRKzZs2K733ve/Gd73wn+vfvH/PmzYsDDjhgxx4FAAAAAHxITsHXz372sy0+v3jx4gbbxowZE2PGjMmpKAAAAADYXoVbXwIAAAAA+UfwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAm5RR8zZgxIz796U9Hhw4dokuXLjF69Oh44YUXtrjPzJkzo6CgoN5P27Ztt6toAAAAANianIKvhx9+OM4555x47LHHoqysLGpqauLYY4+NDRs2bHG/kpKSWLVqVd3PK6+8sl1FAwAAAMDW7JzL4gULFtR7PHPmzOjSpUs8+eST8a//+q9N7ldQUBDdunXbtgoBAAAAYBts1z2+KisrIyKiU6dOW1y3fv366NOnT/Tq1StOOumkeO6557bnbQEAAABgq3K64uuDamtrY/LkyXHkkUfGAQcc0OS6vffeO2677bY48MADo7KyMq688soYMmRIPPfcc9GzZ89G96muro7q6uq6x1VVVRERUVNTEzU1Ndt
acqZs7oN+pEO/07O5x8WFSStXkrt8PD+c2+nS74b0gh3JDLl1PofSpd/pytc5Mh/PD+d2uvS7oVx6UZAkyTZ9Knzta1+L++67Lx555JEmA6ymitt3331j7NixcdlllzW6Ztq0aTF9+vQG22fNmhXt2rXblnIBgI+gjRs3xmmnnRaVlZVRUlLS2uWQ58yQAPDxkMsMuU3B16RJk+Lee++NJUuWRL9+/XIucMyYMbHzzjvHr371q0afb+y3db169Yo1a9YYiv+ppqYmysrKYsSIEVFUVNTa5WSefqdnc68vXloY1bUFrV1OTpZPK23tEnLm3E6XfjdUVVUVnTt3FnyxQ5ght87nULr0O135OkeaIdka/W4olxkyp686JkkS3/jGN2Lu3LmxePHibQq9Nm3aFM8++2wcf/zxTa4pLi6O4uLiBtuLior8I3+InqRLv9NTXVsQ1ZvyZ2CJiLw+N5zb6dLv9+kDO5IZsvn0JF36na58myPz+dxwbqdLv9+XSx9yCr7OOeecmDVrVtx7773RoUOHqKioiIiIXXfdNXbZZZeIiBg3blz06NEjZsyYERERl156aRxxxBGx5557xttvvx1XXHFFvPLKK3HmmWfm8tYAAAAAkJOcgq+f/OQnERExbNiwettvv/32OP300yMi4tVXX43Cwvf/WOTf/va3OOuss6KioiJ22223OPTQQ+PRRx+N/fbbb/sqBwAAAIAtyPmrjluzePHieo+vueaauOaaa3IqCgAAAAC2V+HWlwAAAABA/hF8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJOQVfM2bMiE9/+tPRoUOH6NKlS4wePTpeeOGFre539913xz777BNt27aNAQMGxPz587e5YAAAAABojpyCr4cffjjOOeeceOyxx6KsrCxqamri2GOPjQ0bNjS5z6OPPhpjx46NiRMnxtNPPx2jR4+O0aNHx/Lly7e7eAAAAABoys65LF6wYEG9xzNnzowuXbrEk08+Gf/6r//a6D7XXXddHHfccXHBBRdERMRll10WZWVlcf3118dNN920jWUDAAAAwJZt1z2+KisrIyKiU6dOTa4pLy+P4cOH19tWWloa5eXl2/PWAAAAALBFOV3x9UG1tbUxefLkOPLII+OAAw5ocl1FRUV07dq13rauXbtGRUVFk/tUV1dHdXV13eOqqqqIiKipqYmampptLTlTNvdBP9Kh3+nZ3OPiwqSVK8ldPp4fzu106XdDesGOZIbcOp9D6dLvdOXrHJmP54dzO1363VAuvShIkmSbPhW+9rWvxX333RePPPJI9OzZs8l1bdq0iTvuuCPGjh1bt+3GG2+M6dOnx+rVqxvdZ9q0aTF9+vQG22fNmhXt2rXblnIBgI+gjRs3xmmnnRaVlZVRUlLS2uWQ58yQAPDxkMsMuU3B16RJk+Lee++NJUuWRL9+/ba4tnfv3jFlypSYPHly3bapU6fGvHnz4plnnml0n8Z+W9erV69Ys2aNofifampqoqysLEaMGBFFRUWtXU7m6Xd6Nvf
64qWFUV1b0Nrl5GT5tNLWLiFnzu106XdDVVVV0blzZ8EXO4QZcut8DqVLv9OVr3OkGZKt0e+Gcpkhc/qqY5Ik8Y1vfCPmzp0bixcv3mroFRExePDgWLhwYb3gq6ysLAYPHtzkPsXFxVFcXNxge1FRkX/kD9GTdOl3eqprC6J6U/4MLBGR1+eGcztd+v0+fWBHMkM2n56kS7/TlW9zZD6fG87tdOn3+3LpQ07B1znnnBOzZs2Ke++9Nzp06FB3n65dd901dtlll4iIGDduXPTo0SNmzJgRERHnnXdeDB06NK666qoYNWpUzJ49O5YuXRq33HJLLm8NAAAAADnJ6a86/uQnP4nKysoYNmxY7L777nU/d955Z92aV199NVatWlX3eMiQITFr1qy45ZZbYuDAgTFnzpyYN2/eFm+IDwAAAADbK+evOm7N4sWLG2wbM2ZMjBkzJpe3AgAAAIDtktMVXwAAAACQLwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATMo5+FqyZEmceOKJ0b179ygoKIh58+Ztcf3ixYujoKCgwU9FRcW21gwAAAAAW5Vz8LVhw4YYOHBg3HDDDTnt98ILL8SqVavqfrp06ZLrWwMAAABAs+2c6w4jR46MkSNH5vxGXbp0iY4dO+a8HwAAAABsi5yDr2110EEHRXV1dRxwwAExbdq0OPLII5tcW11dHdXV1XWPq6qqIiKipqYmampqWrzWfLC5D/qRDv1Oz+YeFxcmrVxJ7vLx/HBup0u/G9ILdiQz5Nb5HEqXfqcrX+fIfDw/nNvp0u+GculFQZIk2/ypUFBQEHPnzo3Ro0c3ueaFF16IxYsXx2GHHRbV1dVx6623xi9+8Yt4/PHH45BDDml0n2nTpsX06dMbbJ81a1a0a9duW8sFAD5iNm7cGKeddlpUVlZGSUlJa5dDnjNDAsDHQy4zZIsHX40ZOnRo9O7dO37xi180+nxjv63r1atXrFmzxlD8TzU1NVFWVhYjRoyIoqKi1i4n8/Q7PZt7ffHSwqiuLWjtcnKyfFppa5eQM+d2uvS7oaqqqujcubPgix3CDLl1PofSpd/pytc50gzJ1uh3Q7nMkKl91fGDBg0aFI888kiTzxcXF0dxcXGD7UVFRf6RP0RP0qXf6amuLYjqTfkzsEREXp8bzu106ff79IEdyQzZfHqSLv1OV77Nkfl8bji306Xf78ulDzn/VccdYdmyZbH77ru3xlsDAAAA8DGR8xVf69evjxUrVtQ9XrlyZSxbtiw6deoUvXv3josuuihee+21+PnPfx4REddee23069cv9t9//3jnnXfi1ltvjYceeigeeOCBHXcUAAAAAPAhOQdfS5cujc9+9rN1j6dMmRIREePHj4+ZM2fGqlWr4tVXX617/t13341vfetb8dprr0W7du3iwAMPjAcffLDeawAAAADAjpZz8DVs2LDY0v3wZ86cWe/xhRdeGBdeeGHOhQEAAADA9miVe3wBAAAAQEsTfAEAAACQSYIvAAAAADJJ8AU
AAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSTkHX0uWLIkTTzwxunfvHgUFBTFv3ryt7rN48eI45JBDori4OPbcc8+YOXPmNpQKAAAAAM2Xc/C1YcOGGDhwYNxwww3NWr9y5coYNWpUfPazn41ly5bF5MmT48wzz4z7778/52IBAAAAoLl2znWHkSNHxsiRI5u9/qabbop+/frFVVddFRER++67bzzyyCNxzTXXRGlpaa5vDwAAAADNknPwlavy8vIYPnx4vW2lpaUxefLkJveprq6O6urqusdVVVUREVFTUxM1NTUtUme+2dwH/UiHfqdnc4+LC5NWriR3+Xh+OLfTpd8N6QU7khly63wOpUu/05Wvc2Q+nh/O7XTpd0O59KIgSZJt/lQoKCiIuXPnxujRo5tcs9dee8WECRPioosuqts2f/78GDVqVGzcuDF22WWXBvtMmzYtpk+f3mD7rFmzol27dttaLgDwEbNx48Y47bTTorKyMkpKSlq7HPKcGRIAPh5ymSFb/IqvbXHRRRfFlClT6h5XVVVFr1694thjjzUU/1NNTU2UlZXFiBEjoqioqLXLyTz9Ts/mXl+8tDCqawtau5ycLJ+Wf1/fdm6nS78b2nxFDuwIZsit8zmULv1OV77OkWZItka/G8plhmzx4Ktbt26xevXqettWr14dJSUljV7tFRFRXFwcxcXFDbYXFRX5R/4QPUmXfqenurYgqjflz8ASEXl9bji306Xf79MHdiQzZPPpSbr0O135Nkfm87nh3E6Xfr8vlz7k/FcdczV48OBYuHBhvW1lZWUxePDgln5rAAAAAD7Gcg6+1q9fH8uWLYtly5ZFRMTKlStj2bJl8eqrr0bEPy4xHzduXN36r371q/HnP/85Lrzwwnj++efjxhtvjLvuuiu++c1v7pgjAAAAAIBG5Bx8LV26NA4++OA4+OCDIyJiypQpcfDBB8cll1wSERGrVq2qC8EiIvr16xe/+93voqysLAYOHBhXXXVV3HrrrVFamn/fYwYAAAAgf+R8j69hw4bFlv4Q5MyZMxvd5+mnn871rQAAAABgm7X4Pb4AAAAAoDUIvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAA
ySfAFAAAAQCYJvgAAAADIpG0Kvm644Ybo27dvtG3bNg4//PB44oknmlw7c+bMKCgoqPfTtm3bbS4YAAAAAJoj5+DrzjvvjClTpsTUqVPjqaeeioEDB0ZpaWm88cYbTe5TUlISq1atqvt55ZVXtqtoAAAAANianIOvq6++Os4666yYMGFC7LfffnHTTTdFu3bt4rbbbmtyn4KCgujWrVvdT9euXberaAAAAADYmp1zWfzuu+/Gk08+GRdddFHdtsLCwhg+fHiUl5c3ud/69eujT58+UVtbG4ccckj84Ac/iP3337/J9dXV1VFdXV33uKqqKiIiampqoqamJpeSM2tzH/QjHfqdns09Li5MWrmS3OXj+eHcTpd+N6QX7EhmyK3zOZQu/U5Xvs6R+Xh+OLfTpd8N5dKLgiRJmv2p8Prrr0ePHj3i0UcfjcGDB9dtv/DCC+Phhx+Oxx9/vME+5eXl8eKLL8aBBx4YlZWVceWVV8aSJUviueeei549ezb6PtOmTYvp06c32D5r1qxo165dc8sFAD7iNm7cGKeddlpUVlZGSUlJa5dDnjNDAsDHQy4zZIsHXx9WU1MT++67b4wdOzYuu+yyRtc09tu6Xr16xZo1awzF/1RTUxNlZWUxYsSIKCoqau1yMk+/07O51xcvLYzq2oLWLicny6eVtnYJOXNup0u/G6qqqorOnTsLvtghzJBb53MoXfqdrnydI82QbI1+N5TLDJnTVx07d+4cO+20U6xevbre9tWrV0e3bt2a9RpFRUVx8MEHx4oVK5pcU1xcHMXFxY3u6x+5Pj1Jl36np7q2IKo35c/AEhF5fW44t9Ol3+/TB3YkM2Tz6Um69Dtd+TZH5vO54dxOl36/L5c+5HRz+zZt2sShhx4aCxcurNtWW1sbCxcurHcF2JZs2rQpnn322dh9991zeWsAAAAAyElOV3xFREyZMiXGjx8fhx12WAwaNCiuvfba2LBhQ0yYMCEiIsaNGxc9evSIGTNmRETEpZdeGkcccUTsueee8fbbb8cVV1wRr7zySpx55pk79kgAAAAA4ANyDr6+8IUvxJtvvhmXXHJJVFRUxEEHHRQLFiyIrl27RkTEq6++GoWF719I9re//S3OOuusqKioiN122y0OPfTQePTRR2O//fbbcUcBAAAAAB+Sc/AVETFp0qSYNGlSo88tXry43uNrrrkmrrnmmm15GwAAAADYZjnd4wsAAAAA8oXgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBMEnwBAAAAkEmCLwAAAAAySfAFAAAAQCYJvgAAAADIJMEXAAAAAJkk+AIAAAAgkwRfAAAAAGSS4AsAAACATBJ8AQAAAJBJgi8AAAAAMknwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBM2qbg64Ybboi+fftG27Zt4/DDD48nnnhii+vvvvvu2GeffaJt27YxYMCAmD9//jYVCwAAAADNlXPwdeedd8aUKVNi6tSp8dRTT8XAgQOjtLQ03njjjUbXP/roozF27NiYOHFiPP300zF69OgYPXp0LF++fLuLBwAAAICm5Bx8XX311XHWWWfFhAkTYr/99oubbrop2rVrF7fddluj66+77ro47rjj4oILLoh99903LrvssjjkkEPi+uuv3+7iAQAAAKApO+ey+N13340
nn3wyLrroorpthYWFMXz48CgvL290n/Ly8pgyZUq9baWlpTFv3rwm36e6ujqqq6vrHldWVkZExNq1a6OmpiaXkjOrpqYmNm7cGG+99VYUFRW1djmZp9/p2dzrnWsKY1NtQWuXk5O33nqrtUvImXM7Xfrd0Lp16yIiIkmSVq6ELDBDbp3PoXTpd7rydY40Q7I1+t1QLjNkTsHXmjVrYtOmTdG1a9d627t27RrPP/98o/tUVFQ0ur6ioqLJ95kxY0ZMnz69wfZ+/frlUi5Aqjpf1doVQP5at25d7Lrrrq1dBnnODAnkIzMkbLvmzJA5BV9pueiii+pdJVZbWxtr166NT3ziE1FQkD/JfUuqqqqKXr16xV/+8pcoKSlp7XIyT7/To9fp0u906XdDSZLEunXronv37q1dChlghtw6n0Pp0u906Xd69Dpd+t1QLjNkTsFX586dY6eddorVq1fX27569ero1q1bo/t069Ytp/UREcXFxVFcXFxvW8eOHXMp9WOjpKTEiZ8i/U6PXqdLv9Ol3/W50osdxQzZfD6H0qXf6dLv9Oh1uvS7vubOkDnd3L5NmzZx6KGHxsKFC+u21dbWxsKFC2Pw4MGN7jN48OB66yMiysrKmlwPAAAAADtCzl91nDJlSowfPz4OO+ywGDRoUFx77bWxYcOGmDBhQkREjBs3Lnr06BEzZsyIiIjzzjsvhg4dGldddVWMGjUqZs+eHUuXLo1bbrllxx4JAAAAAHxAzsHXF77whXjzzTfjkksuiYqKijjooINiwYIFdTewf/XVV6Ow8P0LyYYMGRKzZs2K733ve/Gd73wn+vfvH/PmzYsDDjhgxx3Fx1BxcXFMnTq1weX8tAz9To9ep0u/06XfQGvzOZQu/U6XfqdHr9Ol39unIPH3wwEAAADIoJzu8QUAAAAA+ULwBQAAAEAmCb4AAAAAyCTBFwAAAACZJPjKE2vXro0vfelLUVJSEh07doyJEyfG+vXrm7VvkiQxcuTIKCgoiHnz5rVsoRmRa7/Xrl0b3/jGN2LvvfeOXXbZJXr37h3nnntuVFZWplh1/rjhhhuib9++0bZt2zj88MPjiSee2OL6u+++O/bZZ59o27ZtDBgwIObPn59SpdmQS79/+tOfxlFHHRW77bZb7LbbbjF8+PCt/vtQX67n92azZ8+OgoKCGD16dMsWCHysmCHTZYZsWWbIdJkh02WGbDmCrzzxpS99KZ577rkoKyuL3/72t7FkyZI4++yzm7XvtddeGwUFBS1cYbbk2u/XX389Xn/99bjyyitj+fLlMXPmzFiwYEFMnDgxxarzw5133hlTpkyJqVOnxlNPPRUDBw6M0tLSeOONNxpd/+ijj8bYsWNj4sSJ8fTTT8fo0aNj9OjRsXz58pQrz0+59nvx4sUxduzYWLRoUZSXl0evXr3i2GOPjddeey3lyvNTrv3e7OWXX47zzz8/jjrqqJQqBT4uzJDpMkO2HDNkusyQ6TJDtrCEj7w//vGPSUQk//M//1O37b777ksKCgqS1157bYv7Pv3000mPHj2SVatWJRGRzJ07t4WrzX/b0+8Puuuuu5I2bdokNTU1LVFm3ho0aFByzjnn1D3etGlT0r1792TGjBmNrj/11FOTUaNG1dt2+OGHJ1/5yldatM6syLXfH/bee+8lHTp0SO64446WKjFTtqXf7733XjJkyJDk1ltvTcaPH5+cdNJJKVQKfByYIdNlhmxZZsh0mSHTZYZsWa74ygPl5eXRsWPHOOyww+q2DR8+PAoLC+Pxxx9vcr+NGzfGaaedFjfccEN069YtjVIzYVv7/WGVlZVRUlISO++8c0uUmZfefffdePLJJ2P48OF12woLC2P48OFRXl7e6D7l5eX11kdElJaWNrme921Lvz9s48aNUVNTE506dWqpMjNjW/t96aWXRpcuXfx2H9jhzJDpMkO2HDNkusyQ6TJDtjyfpnmgoqIiunTpUm/bzjvvHJ06dYqKioom9/v
mN78ZQ4YMiZNOOqmlS8yUbe33B61ZsyYuu+yyZn+V4ONizZo1sWnTpujatWu97V27do3nn3++0X0qKioaXd/cf4uPs23p94f9v//3/6J79+4NBkca2pZ+P/LII/Gzn/0sli1blkKFwMeNGTJdZsiWY4ZMlxkyXWbIlueKr1b07W9/OwoKCrb409wPlg/7zW9+Ew899FBce+21O7boPNaS/f6gqqqqGDVqVOy3334xbdq07S8cWsnll18es2fPjrlz50bbtm1bu5zMWbduXXz5y1+On/70p9G5c+fWLgfII2bIdJkhITdmyJZlhsydK75a0be+9a04/fTTt7jmU5/6VHTr1q3BTe3ee++9WLt2bZOXnz/00EPx0ksvRceOHettP/nkk+Ooo46KxYsXb0fl+akl+73ZunXr4rjjjosOHTrE3Llzo6ioaHvLzpTOnTvHTjvtFKtXr663ffXq1U32tlu3bjmt533b0u/Nrrzyyrj88svjwQcfjAMPPLAly8yMXPv90ksvxcsvvxwnnnhi3bba2tqI+McVAi+88ELsscceLVs0kJfMkOkyQ7Y+M2S6zJDpMkOmoLVvMsbWbb5R5tKlS+u23X///Vu8UeaqVauSZ599tt5PRCTXXXdd8uc//zmt0vPStvQ7SZKksrIyOeKII5KhQ4cmGzZsSKPUvDRo0KBk0qRJdY83bdqU9OjRY4s3Jj3hhBPqbRs8eLAbkzZTrv1OkiT54Q9/mJSUlCTl5eVplJgpufT773//e4PP6ZNOOik5+uijk2effTaprq5Os3Qgg8yQ6TJDtiwzZLrMkOkyQ7YswVeeOO6445KDDz44efzxx5NHHnkk6d+/fzJ27Ni65//6178me++9d/L44483+RrhL/I0W679rqysTA4//PBkwIAByYoVK5JVq1bV/bz33nutdRgfSbNnz06Ki4uTmTNnJn/84x+Ts88+O+nYsWNSUVGRJEmSfPnLX06+/e1v163/wx/+kOy8887JlVdemfzpT39Kpk6dmhQVFSXPPvtsax1CXsm135dffnnSpk2bZM6cOfXO43Xr1rXWIeSVXPv9Yf4iD7CjmSHTZYZsOWbIdJkh02WGbFmCrzzx1ltvJWPHjk3at2+flJSUJBMmTKj3IbJy5cokIpJFixY1+RqGlubLtd+LFi1KIqLRn5UrV7bOQXyE/fjHP0569+6dtGnTJhk0aFDy2GOP1T03dOjQZPz48fXW33XXXclee+2VtGnTJtl///2T3/3udylXnN9y6XefPn0aPY+nTp2afuF5Ktfz+4MMLcCOZoZMlxmyZZkh02WGTJcZsuUUJEmStOyXKQEAAAAgff6qIwAAAACZJPgCAAAAIJMEXwAAAABkkuALAAAAgEwSfAEAAACQSYIvAAAAADJJ8AUAAABAJgm+AAAAAMgkwRcAAAAAmST4AgAAACCTBF8AAAAAZJLgCwAAAIBM+v8Yvbf9gvoJBAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(15, 5))\n", "\n", "df.hist(\n", " column=\"line_score\",\n", " ax=axes[0],\n", ")\n", "\n", "df.hist(\n", " column=\"block_score\",\n", " ax=axes[1],\n", ")" ] }, { "cell_type": "code", "execution_count": 67, "id": "3b8d339e-4452-4e2c-823c-0f69f6eb4805", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 promptpredictionlabelblock_scoreline_score
0 attentions=all_attentions,\n", " cross_attentions=all_cross_attentions,\n", " )\n", "\n", "\n", "class T5ForConditionalGeneration(T5PreTrainedModel):\n", " def __init__(self, config: T5Config, weights):\n", " super().__init__(config)\n", " self.model_dim = config.d_model\n", "\n", " try:\n", " self.shared = TensorParallelEmbedding(prefix=\"shared\", weights=weights)\n", " except RuntimeError:\n", " self.shared = TensorParallelEmbedding(prefix=\"encoder.embed_tokens\", weights=weights)\n", "\n", " encoder_config = copy.deepcopy(config)\n", " encoder_config.is_decoder = False\n", " encoder_config.use_cache = False\n", " encoder_config.is_encoder_decoder = False\n", " self.encoder = /*\n", " * Copyright (c) 2008-2021, Hazelcast, Inc. All Rights Reserved.\n", " *\n", " * Licensed under the Apache License, Version 2.0 (the \"License\");\n", " * you may not use this file except in compliance with the License.\n", " * You may obtain a copy of the License at\n", " *\n", " * http://www.apache.org/licenses/LICENSE-2.0\n", " *\n", " * Unless required by applicable law or agreed to in writing, software\n", " * distributed under the License is distributed on an \"AS IS\" BASIS,T5Stack(\n", " config=encoder_config,\n", " prefix=\"encoder\",\n", " weights=weights,\n", " embed_tokens=self.shared,\n", " )0.0000000.000000
1 past_present_indices,\n", " past_key_values: Optional[torch.Tensor] = None,\n", " pre_allocate_past_size: Optional[int] = None,\n", " lm_head_indices: Optional[torch.Tensor] = None,\n", " ):\n", " hidden_states, present = self.gpt_neox(\n", " input_ids,\n", " position_ids,\n", " start_seq,\n", " end_seq,\n", " start_seq_q,\n", " end_seq_q,\n", " max_s,\n", " past_present_indices,\n", " past_key_values,\n", " pre_allocate_past_size,\n", " )\n", " if lm_head_indices is not None:\n", " hidden_states = hidden_states[lm_head_indices]\n", " logits = /*\n", " * Copyright (c) 2008-2021, Hazelcast, Inc. All Rights Reserved.\n", " *\n", " * Licensed under the Apache License, Version 2.0 (the \"License\");\n", " * you may not use this file except in compliance with the License.\n", " * You may obtain a copy of the License at\n", " *\n", " * http://www.apache.org/licenses/LICENSE-2.0\n", " *\n", " * Unless required by applicable law or agreed to in writing, software\n", " * distributed under the License is distributed on an \"AS IS\" BASIS,self.embed_out(hidden_states)0.0000000.000000
2 if not isinstance(serialized_data, List):\n", " serialized_data = [serialized_data]\n", " if not isinstance(snapshot_data, List):\n", " snapshot_data = [snapshot_data]\n", "\n", " return len(snapshot_data) == len(serialized_data) and all(\n", " [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]\n", " )\n", "\n", "\n", "class LauncherHandle:\n", " def __init__(self, port: int):\n", " self.client = AsyncClient(f\"http://localhost:{port}\")\n", "\n", " def _inner_health(self):\n", " raise NotImplementedError\n", "\n", " async def health(self, timeout: int = 60):\n", " assert timeout > 0\n", " for _ in /*\n", " * Copyright (c) 2008-2021, Hazelcast, Inc. All Rights Reserved.\n", " *\n", " * Licensed under the Apache License, Version 2.0 (the \"License\");\n", " * you may not use this file except in compliance with the License.\n", " * You may obtain a copy of the License at\n", " *\n", " * http://www.apache.org/licenses/LICENSE-2.0\n", " *\n", " * Unless required by applicable law or agreed to in writing, software\n", " * distributed under the License is distributed on an \"AS IS\" BASIS,range(timeout)0.0000000.000000
\n" ], "text/plain": [ "" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "codeStyle = {\n", " \"selector\": \"td\",\n", " \"props\": [\n", " (\"white-space\", \"pre\"),\n", " (\"font-family\", \"monospace\"),\n", " (\"text-align\", \"left\"),\n", " (\"max-width\", \"400px\"),\n", " (\"overflow-x\", \"scroll\"),\n", " ],\n", "}\n", "\n", "df.style.set_table_styles(\n", " {\n", " \"prompt\": [codeStyle],\n", " \"prediction\": [codeStyle],\n", " \"label\": [codeStyle],\n", " }\n", ")" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.10" } }, "nbformat": 4, "nbformat_minor": 5 }