package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.tests.TestSerializable;
import java.io.IOException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:cc/mallet/grmm/test/TestAssignment.class */
/**
 * JUnit 3 tests for {@link Assignment}.
 *
 * <p>Every test works with the same two binary variables created in
 * {@link #setUp()}. Outcome arrays are always written out in full, with one
 * entry per variable, so that each row can be read directly off the test.
 */
public class TestAssignment extends TestCase {
    /** The two variables shared by every test; rebuilt before each test by {@link #setUp()}. */
    private Variable[] vars;

    /**
     * Creates a test instance that runs the test method with the given name.
     *
     * @param str name of the test method to run
     */
    public TestAssignment(String str) {
        super(str);
    }

    /** Creates a fresh pair of two-outcome variables for each test. */
    @Override
    protected void setUp() throws Exception {
        this.vars = new Variable[]{new Variable(2), new Variable(2)};
    }

    /** A single-row assignment reports the outcome given for each variable. */
    public void testSimple() {
        Assignment assignment = new Assignment(this.vars, new int[]{1, 0});
        assertEquals(1, assignment.get(this.vars[0]));
        assertEquals(0, assignment.get(this.vars[1]));
        assertEquals(Integer.valueOf(0), assignment.getObject(this.vars[1]));
    }

    /** After normalizing three rows, the row that occurs twice has value 2/3. */
    public void testScale() {
        Assignment assignment = new Assignment(this.vars, new int[]{1, 0});
        assignment.addRow(this.vars, new int[]{1, 0});
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.normalize();

        Assignment query = new Assignment(this.vars, new int[]{1, 0});
        assertEquals(0.666666d, assignment.value(query), 1.0E-5d);
    }

    /** Marginalizing the normalized rows onto the second variable gives value 2/3 for outcome 0. */
    public void testScaleMarginalize() {
        Assignment assignment = new Assignment(this.vars, new int[]{1, 0});
        assignment.addRow(this.vars, new int[]{1, 0});
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.normalize();

        Assignment query = new Assignment(this.vars[1], 0);
        assertEquals(0.666666d, assignment.marginalize(this.vars[1]).value(query), 1.0E-5d);
    }

    /**
     * A copy made through serialization keeps the number of variables and rows.
     *
     * @throws IOException if serialization fails
     * @throws ClassNotFoundException if deserialization cannot resolve a class
     */
    public void testSerialization() throws IOException, ClassNotFoundException {
        Assignment assignment = new Assignment(this.vars, new int[]{1, 0});
        Assignment copy = (Assignment) TestSerializable.cloneViaSerialization(assignment);
        assertEquals(2, copy.numVariables());
        assertEquals(1, copy.numRows());
        // NOTE(review): the value checks below read the original, not the copy.
        // Kept as-is; confirm whether the copy can be queried with these Variable objects.
        assertEquals(1, assignment.get(this.vars[0]));
        assertEquals(0, assignment.get(this.vars[1]));
    }

    /** Marginalizing onto the first variable keeps both rows and leaves a single variable. */
    public void testMarginalize() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});

        Assignment marginal = (Assignment) assignment.marginalize(this.vars[0]);
        assertEquals(2, marginal.numRows());
        assertEquals(1, marginal.size());
        assertEquals(this.vars[0], marginal.getVariable(0));
        // NOTE(review): these two checks read the source assignment, not the marginal.
        assertEquals(1, assignment.get(0, this.vars[0]));
        assertEquals(1, assignment.get(1, this.vars[0]));
    }

    /** Marginalizing out the second variable keeps both rows and leaves only the first variable. */
    public void testMarginalizeOut() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});

        Assignment marginal = (Assignment) assignment.marginalizeOut(this.vars[1]);
        assertEquals(2, marginal.numRows());
        assertEquals(1, marginal.size());
        assertEquals(this.vars[0], marginal.getVariable(0));
        // NOTE(review): these two checks read the source assignment, not the marginal.
        assertEquals(1, assignment.get(0, this.vars[0]));
        assertEquals(1, assignment.get(1, this.vars[0]));
    }

    /** The union of two single-variable, single-row assignments has one row over both variables. */
    public void testUnion() {
        Assignment first = new Assignment();
        first.addRow(new Variable[]{this.vars[0]}, new int[]{1});
        Assignment second = new Assignment();
        second.addRow(new Variable[]{this.vars[1]}, new int[]{0});

        Assignment union = Assignment.union(first, second);
        assertEquals(1, union.numRows());
        assertEquals(2, union.numVariables());
        assertEquals(1, union.get(0, this.vars[0]));
        assertEquals(0, union.get(0, this.vars[1]));
    }

    /**
     * Rows are addressable by index, and the row-less {@code get(Variable)} is
     * expected to throw once the assignment holds more than one row.
     */
    public void testMultiRow() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});
        assertEquals(2, assignment.numRows());
        assertEquals(1, assignment.get(0, this.vars[1]));
        assertEquals(0, assignment.get(1, this.vars[1]));
        try {
            assignment.get(this.vars[1]);
            fail("Expected IllegalArgumentException from get(Variable) on a two-row assignment");
        } catch (IllegalArgumentException expected) {
            // Expected: this is the behavior under test, so there is nothing to handle.
        }
    }

    /** Overwriting row 0 with an outcome array changes that row only. */
    public void testSetRow() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});
        assertEquals(1, assignment.get(0, this.vars[0]));

        assignment.setRow(0, new int[]{0, 0});
        assertEquals(2, assignment.numRows());
        assertEquals(0, assignment.get(0, this.vars[0]));
        assertEquals(0, assignment.get(0, this.vars[1]));
        assertEquals(1, assignment.get(1, this.vars[0]));
        assertEquals(0, assignment.get(1, this.vars[1]));
    }

    /** Overwriting row 0 from another assignment changes that row only. */
    public void testSetRowFromAssn() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});
        assertEquals(1, assignment.get(0, this.vars[0]));

        Assignment replacement = new Assignment();
        replacement.addRow(this.vars, new int[]{0, 0});
        assignment.setRow(0, replacement);
        assertEquals(2, assignment.numRows());
        assertEquals(0, assignment.get(0, this.vars[0]));
        assertEquals(0, assignment.get(0, this.vars[1]));
        assertEquals(1, assignment.get(1, this.vars[0]));
        assertEquals(0, assignment.get(1, this.vars[1]));
    }

    /** Setting the value of an existing variable updates it in place and leaves the other untouched. */
    public void testSetValue() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.setValue(this.vars[0], 0);
        assertEquals(1, assignment.numRows());
        assertEquals(0, assignment.get(0, this.vars[0]));
        assertEquals(1, assignment.get(0, this.vars[1]));
    }

    /** Setting a value works the same way on a duplicate of an assignment. */
    public void testSetValueDup() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});

        Assignment copy = (Assignment) assignment.duplicate();
        copy.setValue(this.vars[0], 0);
        assertEquals(1, copy.numRows());
        assertEquals(0, copy.get(0, this.vars[0]));
        assertEquals(1, copy.get(0, this.vars[1]));
    }

    /** Setting a value for a variable that is not yet present grows the assignment to three variables. */
    public void testSetValueExpand() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{0, 0});

        Variable extra = new Variable(2);
        assignment.setValue(extra, 1);
        assertEquals(3, assignment.size());
        assertEquals(0, assignment.get(this.vars[0]));
        assertEquals(0, assignment.get(this.vars[1]));
        assertEquals(1, assignment.get(extra));
    }

    /** {@code asTable()} on three rows is compared against a table with values {0, 0, 2, 1}. */
    public void testAsTable() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});
        assignment.addRow(this.vars, new int[]{1, 0});

        TableFactor expected = new TableFactor(this.vars, new double[]{0.0d, 0.0d, 2.0d, 1.0d});
        assertTrue(expected.almostEquals(assignment.asTable()));
    }

    /**
     * Adding a row from an assignment whose variables are listed in the opposite
     * order yields the same table as in {@link #testAsTable()}.
     */
    public void testAddRowMixed() {
        Assignment assignment = new Assignment();
        assignment.addRow(this.vars, new int[]{1, 1});
        assignment.addRow(this.vars, new int[]{1, 0});

        // Same row as {1, 0} above, but with the variables given as (vars[1], vars[0]).
        Assignment reordered = new Assignment();
        reordered.addRow(new Variable[]{this.vars[1], this.vars[0]}, new int[]{0, 1});
        assignment.addRow(reordered);

        TableFactor expected = new TableFactor(this.vars, new double[]{0.0d, 0.0d, 2.0d, 1.0d});
        assertTrue(expected.almostEquals(assignment.asTable()));
    }

    /**
     * Builds a suite from this class.
     *
     * @return the suite created from {@code TestAssignment.class}
     */
    public static Test suite() {
        return new TestSuite(TestAssignment.class);
    }

    /**
     * Runs the tests with the text-based runner.
     *
     * @param strArr names of the test methods to run; when empty, the full suite is run
     * @throws Exception declared for compatibility with the original signature
     */
    public static void main(String[] strArr) throws Exception {
        TestSuite testSuite;
        if (strArr.length > 0) {
            // Run only the tests named on the command line.
            testSuite = new TestSuite();
            for (String testName : strArr) {
                testSuite.addTest(new TestAssignment(testName));
            }
        } else {
            testSuite = (TestSuite) suite();
        }
        TestRunner.run(testSuite);
    }
}
