{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "mixture-overview",
   "metadata": {},
   "source": [
    "# Bayesian Gaussian mixture with known component scales (PyStan)\n",
    "\n",
    "We simulate `N` points from a `K`-component Gaussian mixture and then recover the component means `mu` and the mixing weights `p` with Stan. The component standard deviations `sigma` are passed to the model as known data.\n",
    "\n",
    "$$x_n \\sim \\sum_{k=1}^{K} p_k\\,\\mathcal{N}(\\mu_k, \\sigma_k^2), \\qquad \\mu_k \\sim \\mathcal{N}(\\mu_{0,k}, \\sigma_{0,k}^2), \\qquad p \\sim \\mathrm{Dirichlet}(c)$$\n",
    "\n",
    "The discrete component labels are summed out of the likelihood with `log_sum_exp`, and `mu` is declared `ordered` so that every posterior draw uses the same component labelling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "atmospheric-upper",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import nest_asyncio\n",
    "import numpy as np\n",
    "\n",
    "# PyStan 3 drives its own asyncio event loop; patch asyncio so that loop can\n",
    "# nest inside the one Jupyter is already running. Applied before importing stan.\n",
    "nest_asyncio.apply()\n",
    "\n",
    "import stan"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "simulate-header",
   "metadata": {},
   "source": [
    "## Simulate data from a known mixture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "warming-fields",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth mixture used to simulate the data\n",
    "N = 1000  # number of observations\n",
    "K = 4  # number of mixture components\n",
    "\n",
    "p = np.array([0.25, 0.25, 0.25, 0.25])  # mixing weights\n",
    "mu = np.array([1, 3, 5, 7])  # component means\n",
    "sigma = np.array([0.5, 0.5, 0.5, 0.5])  # component std devs (known to the model)\n",
    "\n",
    "assert p.shape == mu.shape == sigma.shape == (K,)\n",
    "assert np.isclose(p.sum(), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "academic-funds",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw each point's component label according to the mixing weights p, then\n",
    "# draw the observation from that component's Gaussian.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "Zt = rng.choice(K, size=N, p=p)  # true component label of every point\n",
    "X = rng.normal(mu[Zt], sigma[Zt])\n",
    "\n",
    "assert X.shape == (N,)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "model-header",
   "metadata": {},
   "source": [
    "## Priors and Stan model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lined-giving",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior hyper-parameters, one entry per component\n",
    "mu0 = np.zeros(K)  # prior mean of each mu[k]\n",
    "sigma0 = np.full(K, 10.0)  # prior std dev of each mu[k]\n",
    "c = np.ones(K)  # Dirichlet concentration for p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "internal-malawi",
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster_code = \"\"\"\n",
    "data {\n",
    "  int<lower=1> N;             // number of points\n",
    "  int<lower=1> K;             // number of components\n",
    "  array[N] real X;            // observations\n",
    "  vector<lower=0>[K] sigma;   // known component std devs\n",
    "  vector<lower=0>[K] sigma0;  // prior std dev of each component mean\n",
    "  vector[K] mu0;              // prior mean of each component mean\n",
    "  vector<lower=0>[K] c;       // Dirichlet concentration of the weights\n",
    "}\n",
    "\n",
    "parameters {\n",
    "  ordered[K] mu;              // component means; ordering fixes the labelling\n",
    "  simplex[K] p;               // mixing weights\n",
    "}\n",
    "\n",
    "model {\n",
    "  vector[K] log_p = log(p);   // hoisted: identical for every observation\n",
    "\n",
    "  // Priors, each applied exactly once.\n",
    "  mu ~ normal(mu0, sigma0);\n",
    "  p ~ dirichlet(c);\n",
    "\n",
    "  // Likelihood: the component label of each point is summed out, so the\n",
    "  // mixing weight belongs inside the log_sum_exp with the component density.\n",
    "  for (n in 1:N) {\n",
    "    vector[K] lps = log_p;\n",
    "    for (k in 1:K) {\n",
    "      lps[k] += normal_lpdf(X[n] | mu[k], sigma[k]);\n",
    "    }\n",
    "    target += log_sum_exp(lps);\n",
    "  }\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sampling-header",
   "metadata": {},
   "source": [
    "## Sample the posterior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "posterior-sampling",
   "metadata": {},
   "outputs": [],
   "source": [
    "cluster_data = {\n",
    "    \"N\": N,\n",
    "    \"K\": K,\n",
    "    \"X\": X,\n",
    "    \"mu0\": mu0,\n",
    "    \"sigma0\": sigma0,\n",
    "    \"sigma\": sigma,\n",
    "    \"c\": c,\n",
    "}\n",
    "\n",
    "posterior = stan.build(cluster_code, data=cluster_data, random_seed=1)\n",
    "fit = posterior.sample(num_chains=4, num_samples=1000)\n",
    "\n",
    "# Posterior draws are stored under their own names so the true `mu` and `p`\n",
    "# defined above stay available for comparison and for re-running the simulation.\n",
    "# Each array has one row per component and one column per draw (all chains).\n",
    "mu_draws = fit[\"mu\"]\n",
    "p_draws = fit[\"p\"]\n",
    "\n",
    "assert mu_draws.shape[0] == K and p_draws.shape[0] == K"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "results-header",
   "metadata": {},
   "source": [
    "## Results\n",
    "\n",
    "Posterior means should sit close to the true values used in the simulation. Because `mu` is `ordered` in the model, row `k` of `mu_draws` always refers to the `k`-th smallest mean and no post-hoc sorting is needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "threatened-guarantee",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"true mu:          \", mu)\n",
    "print(\"posterior mean mu:\", mu_draws.mean(axis=1))\n",
    "print(\"true p:           \", p)\n",
    "print(\"posterior mean p: \", p_draws.mean(axis=1))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "lines = ax.plot(mu_draws.T)\n",
    "ax.set(\n",
    "    title=\"Posterior draws of the component means\",\n",
    "    xlabel=\"draw (all chains)\",\n",
    "    ylabel=\"mu\",\n",
    ")\n",
    "ax.legend(lines, [f\"mu[{k + 1}]\" for k in range(K)], loc=\"center right\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}