Een reeks vectoren ($$M$$ dimensies) wordt gegroepeerd in 2 subgroepen, namelijk de subgroep 0 en de subgroep 1. We beschikken hiertoe over een verzameling van $$N$$ voorbeelden, waarvan we dus de subgroep kennen. We noemen $$\mathbf{x}_i$$ het voorbeeld met rangnummer $$i$$, en de bijhorende subgroep noemen we $$y_i$$. $$ $$

We willen nu voor een nieuwe vector bepalen tot welke subgroep die behoort. Om dit op een wiskundig elegante manier te doen, construeren we uit elke kolomvector $$\mathbf{x}_i$$ de kolomvector $$\mathbf{u}_i$$, namelijk:
$$ \begin{align} \mathbf{u} &= \begin{bmatrix} 1 \\ x_{0} \\ x_{1} \\ \vdots \\ x_{M-1} \end{bmatrix} \end{align} $$

Deze kolomvector telt dus $$M+1$$ componenten. We vermoeden dat er voor de voorbeelden een $$(M+1)$$-dimensionale vector $$\mathbf{w}$$ kan gevonden worden, zodat de klasse $$y_i$$ voor het merendeel van de voorbeelden $$\mathbf{u}_i$$ gegeven wordt door:
$$ y_i = \Big\{ \begin{matrix} 1 & \text{voor} & \mathbf{w}^T \mathbf{u}_i \gt 0 \\ 0 & \text{anders} & \\ \end{matrix} $$

Om de kolomvector $$\mathbf{w}$$ te vinden, gaan we iteratief tewerk. Noem $$\mathbf{w}_i$$ de waarde voor de kolomvector $$\mathbf{w}$$ uit de iteratie met rangnummer $$i$$. $$ $$

Programmeer de functie bepaal_w() met als argumenten:

Het resultaat van de functie is de 1D-rij die de kolomvector $$\mathbf{w}_k$$, berekend zoals hierboven aangegeven, voorstelt.

Voorbeeld

Merk op dat het Dodonascript je resultaat omzet naar een lijst. Het resultaat van je functie moet wel degelijk een 1D NumPy-rij zijn. De numerieke waarden worden ook afgekapt op 4 decimalen.

bepaal_w(np.array([[5.0, 0.2, -9.3, -0.4, 1.0], [-6.4, 5.2, -5.8, 5.7, 0.0], [-2.4, 8.0, -6.3, 7.8, 0.0], [-1.3, 1.9, -2.1, -2.8, 1.0], [-7.9, 0.9, -3.3, -5.0, 1.0], [-3.8, 4.7, -4.7, 0.9, 0.0], [9.9, -8.0, -1.6, 7.8, 0.0], [-1.9, 9.5, 3.3, -1.4, 0.0], [-0.1, 2.1, -1.9, 7.6, 0.0], [-9.2, -7.6, 3.7, -7.6, 1.0]]), 3)
#[-0.7281, 0.1731, -0.2478, -0.3975, -0.8908]

bepaal_w(np.array([[-9.3, -0.1, -2.0, -7.2, 1.0], [6.6, 0.3, -9.4, 3.8, 0.0], [0.1, 2.2, -4.4, -4.6, 1.0], [-2.0, 3.0, -7.1, -9.9, 1.0], [9.5, -4.4, 9.9, 1.0, 1.0], [-0.3, 3.9, -3.7, 0.1, 1.0], [5.6, 2.1, -5.0, 0.2, 0.0]]), 3)
#[3.2679, -0.4588, 0.7043, 0.8229, -0.3266]